diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index 7be5fc7f47d5db027d120b8024982df93db95b74..1bec17fb5cb12bf1430547a921d460a8a1ae544b 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,252 @@
----
-license: mit
----
+
+# [OTO–HNS2024] A Deep Learning Framework for Analysis of the Eustachian Tube and the Internal Carotid Artery
+
+Ameen Amanian, Aseem Jain, Yuliang Xiao, Chanha Kim, Andy S. Ding, Manish Sahu, Russell Taylor, Mathias Unberath, Bryan K. Ward, Deepa Galaiya, Masaru Ishii, Francis X. Creighton
+
+[News](#news) | [Abstract](#abstract) | [Installation](#installation) | [Train](#train) | [Inference](#inference) | [Evaluation](#evaluation)
+
+
+## News
+
+**2024.04.30** - The data preprocessing, training, inference, and evaluation code is released.
+
+**2024.04.05** - Our paper is accepted to **American Academy of Otolaryngology–Head and Neck Surgery 2024 (OTO-HNS2024)**.
+
+## Abstract
+- Objective: Obtaining automated, objective 3-dimensional (3D)
+models of the Eustachian tube (ET) and the internal carotid
+artery (ICA) from computed tomography (CT) scans could
+provide useful navigational and diagnostic information for ET
+pathologies and interventions. We aim to develop a deep
+learning (DL) pipeline to automatically segment the ET and
+ICA and use these segmentations to compute distances
+between these structures.
+
+- Methods: From a database of 30 CT scans, 60 ET and ICA pairs
+were manually segmented and used to train an nnU-Net model,
+a DL segmentation framework. These segmentations were also
+used to develop a quantitative tool to capture the magnitude
+and location of the minimum distance point (MDP) between ET
+and ICA. Performance metrics for the nnU-Net automated
+segmentations were calculated via the average Hausdorff
+distance (AHD) and dice similarity coefficient (DSC).
+
+- Results: The AHD for the ET and ICA were 0.922 and 0.246 mm,
+respectively. Similarly, the DSC values for the ET and ICA were
+0.578 and 0.884. The mean MDP from ET to ICA in the
+cartilaginous region was 2.6 mm (0.7-5.3 mm) and was located
+on average 1.9 mm caudal from the bony cartilaginous junction.
+
+- Conclusion: This study describes the first end-to-end DL
+pipeline for automated ET and ICA segmentation and analyzes
+distances between these structures. In addition to helping to
+ensure the safe selection of patients for ET dilation, this
+method can facilitate large-scale studies exploring the
+relationship between ET pathologies and the 3D shape of
+the ET.
+
+
+![Overview of Workflow](assets/method.png)
+
+*Figure 1: Overview of Workflow*
+
+
+## Installation
+
+### Step 1: Clone This GitHub Repository
+
+```bash
+git clone https://github.com/mikami520/AutoSeg4ETICA.git && cd AutoSeg4ETICA
+```
+
+### Step 2: Set Up Two Environments Using the requirements.txt Files (a virtual environment is recommended)
+
+```bash
+source /path/to/VIRTUAL_ENVIRONMENT/bin/activate
+pip install -r requirements.txt
+```
+
+## Preprocessing
+
+### Step 1: Register Data to Template
+
+```bash
+cd preprocessing
+```
+
+Register the data to a template (this can also be used to propagate multiple segmentations):
+
+```bash
+python registration.py -bp BASE_PATH -ip IMAGE_FOLDER -sp SEGMENTATION_FOLDER
+```
+
+If you want to make sure the segmentation label names correspond to the correct label values, append the following option to the command above:
+
+```bash
+-sl LabelValue1 LabelName1 LabelValue2 LabelName2 LabelValue3 LabelName3 ...
+```
+
+For example, if you have two maxillary sinus labels named L-MS and R-MS:
+
+```bash
+python registration.py -bp /Users/mikamixiao/Desktop -ip images -sp labels -sl 1 L-MS 2 R-MS
+```
+
+The final registered images and segmentations will be saved in:
+
+```text
+imagesRS/ && labelsRS/
+```
+
+### Step 2: Create the Data Split for Training/Testing (validation is chosen automatically by nnUNet; filenames should follow the format taskname_xxx.nii.gz)
+
+```bash
+python split_data.py -bp BASE_PATH -ip IMAGE_FOLDER -sp SEGMENTATION_FOLDER -sl LABEL_LIST -ti TASK_ID -tn TASK_NAME
+```
+
+For example:
+
+```bash
+python split_data.py -bp /Users/mikamixiao/Desktop -ip imagesRS -sp labelsRS -sl 1 L-MS 2 R-MS -ti 001 -tn Sinus
+```
+
+### Step 3: Setup Bashrc
+
+Edit your `~/.bashrc` file with `gedit ~/.bashrc` or `nano ~/.bashrc`. At the end of the file, add the following lines:
+
+```bash
+export nnUNet_raw_data_base="/nnUNet/nnUNet_raw_data_base"
+export nnUNet_preprocessed="/nnUNet/nnUNet_preprocessed"
+export RESULTS_FOLDER="/nnUNet/nnUNet_trained_models"
+```
+
+After updating this, you will need to source your `~/.bashrc` file.
+
+```bash
+source ~/.bashrc
+```
+
+This will deactivate your current conda environment.
+
+### Step 4: Verify and Preprocess Data
+
+Activate nnUNet environment
+
+```bash
+source /path/to/VIRTUAL_ENVIRONMENT/bin/activate
+```
+
+Run nnUNet preprocessing script.
+
+```bash
+nnUNet_plan_and_preprocess -t TASK_ID --verify_dataset_integrity
+```
+
+Potential Error: You may need to edit the dataset.json file so that the labels are sequential. If you have 10 or more labels, labels `10, 11, 12, ...` will be sorted before labels `2, 3, 4, ...`. Fixing this in a text editor is completely fine; a small helper script is also sketched below.
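+
+A minimal sketch of that fix, assuming the standard nnU-Net v1 `dataset.json` layout (the task path is hypothetical; adjust it to your own task):
+
+```python
+import json
+
+# Hypothetical location inside the raw data folder; adjust to your task.
+path = "nnUNet_raw_data_base/nnUNet_raw_data/Task001_Sinus/dataset.json"
+
+with open(path) as f:
+    dataset = json.load(f)
+
+# Reorder the label mapping numerically (not lexicographically), keeping
+# each label's value unchanged; dicts preserve insertion order in the JSON.
+dataset["labels"] = {k: dataset["labels"][k]
+                     for k in sorted(dataset["labels"], key=int)}
+
+with open(path, "w") as f:
+    json.dump(dataset, f, indent=4)
+```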
+
+## Train
+
+To train the model:
+
+```bash
+nnUNet_train 3d_fullres nnUNetTrainerV2 Task_TemporalBone Y --npz
+```
+
+`Y` refers to the fold index for cross-validation. If `Y` is set to `all`, then all of the data will be used for training. For 5-fold cross-validation, run the command five times with `Y` set to `0, 1, 2, 3, 4`.
+
+`--npz` makes the models save the softmax outputs (uncompressed, large files) during the final validation. It should only be used if you are training multiple configurations, which requires `nnUNet_find_best_configuration` to find the best model. We omit this by default.
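+
+For convenience, the five folds can also be launched sequentially from Python. This is only a sketch that shells out to the `nnUNet_train` command above; it assumes the nnUNet CLI is on your `PATH`:
+
+```python
+import subprocess
+
+# Train all five cross-validation folds, one after another.
+for fold in range(5):
+    subprocess.run(
+        ["nnUNet_train", "3d_fullres", "nnUNetTrainerV2",
+         "Task_TemporalBone", str(fold), "--npz"],
+        check=True,  # abort if a fold fails
+    )
+```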
+
+## Inference
+
+To run inference on trained checkpoints and obtain evaluation results:
+`nnUNet_find_best_configuration` will print a string to the terminal with the inference commands you need to use.
+The easiest way to run inference is to simply use these commands.
+
+If you wish to manually specify the configuration(s) used for inference, use the following commands:
+
+For each of the desired configurations, run:
+
+```bash
+nnUNet_predict -i INPUT_FOLDER -o OUTPUT_FOLDER -t TASK_NAME_OR_ID -m CONFIGURATION --save_npz
+```
+
+Only specify `--save_npz` if you intend to use ensembling. `--save_npz` will make the command save the softmax
+probabilities alongside the predicted segmentation masks, which requires a lot of disk space.
+
+Please select a separate `OUTPUT_FOLDER` for each configuration!
+
+If you wish to run ensembling, you can ensemble the predictions from several configurations with the following command:
+
+```bash
+nnUNet_ensemble -f FOLDER1 FOLDER2 ... -o OUTPUT_FOLDER -pp POSTPROCESSING_FILE
+```
+
+You can specify an arbitrary number of folders, but remember that each folder needs to contain npz files that were
+generated by `nnUNet_predict`. For ensembling you can also specify a file that tells the command how to postprocess.
+These files are created when running `nnUNet_find_best_configuration` and are located in the respective trained model directory `(RESULTS_FOLDER/nnUNet/CONFIGURATION/TaskXXX_MYTASK/TRAINER_CLASS_NAME__PLANS_FILE_IDENTIFIER/postprocessing.json or RESULTS_FOLDER/nnUNet/ensembles/TaskXXX_MYTASK/ensemble_X__Y__Z--X__Y__Z/postprocessing.json)`. You can also choose to not provide a file (simply omit -pp) and nnU-Net will not run postprocessing.
+
+Note that by default, inference will be done with all available folds. We very strongly recommend you use all 5 folds.
+Thus, all 5 folds must have been trained prior to running inference. The list of available folds nnU-Net found will be
+printed at the start of inference.
+
+## Evaluation
+
+To compute the Dice score, average Hausdorff distance, and weighted Hausdorff distance:
+
+```bash
+cd metrics
+```
+
+Run metrics.py to output a CSV file containing the Dice score and Hausdorff distance for each segmentation:
+
+```bash
+python metrics.py -bp BASE_PATH -gp GT_FOLDER -pp PRED_FOLDER -sp SAVE_FOLDER -vt VALIDATION_TYPES
+```
+
+Users can choose any combination of these three evaluation types.
+
+```text
+dsc: Dice Score
+ahd: Average Hausdorff Distance
+whd: Weighted Hausdorff Distance
+```
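+
+For reference, here is a minimal sketch of how the first two metrics are computed on boolean masks. The Dice score mirrors the formula used in `metrics.py`, and the AHD uses DeepMind's `surface_distance` package exactly as `metrics.py` does (the mask names and spacing are illustrative):
+
+```python
+import numpy as np
+import surface_distance
+
+def dice_score(gt: np.ndarray, pred: np.ndarray) -> float:
+    # Dice = 2 * |GT intersect Pred| / (|GT| + |Pred|) on boolean masks
+    volume_sum = gt.sum() + pred.sum()
+    if volume_sum == 0:
+        return np.nan
+    return 2 * np.logical_and(gt, pred).sum() / volume_sum
+
+def average_hausdorff(gt: np.ndarray, pred: np.ndarray, spacing_mm) -> float:
+    # Symmetric average of the two directed average surface distances (mm).
+    sd = surface_distance.compute_surface_distances(gt, pred, spacing_mm)
+    gt_to_pred, pred_to_gt = surface_distance.compute_average_surface_distance(sd)
+    return (gt_to_pred + pred_to_gt) / 2
+```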
+
+If you choose `whd` and do not have a probability map, you can use `get_probability_map.py` to obtain one:
+
+```bash
+python get_probability_map.py -bp BASE_PATH -pp PRED_FOLDER -rr SPLIT_RATIOS -ps PROBABILITY_SEQUENCES
+```
+
+Currently, we split the skeleton along the x-axis, from the ear end to the nasal end. Please make sure the probability sequences match the split regions. The output probability map, a text file, will be stored in `output/` under the base directory. Once you have the probability map, you can supply your customized probability map by adding the following option when running `metrics.py`:
+
+```bash
+-pm PROBABILITY_MAP_PATH
+```
+
+#### Drawing a Heat Map to Visualize Where the Prediction Fails
+
+```bash
+python distanceVertex2Mesh.py -bp BASE_PATH -gp GT_FOLDER -pp PRED_FOLDER
+```
+
+Once you get the closest distances (saved in `output/` under the base directory) from prediction to ground truth, you can easily draw the heat map and use the color bar to visualize the differences (`ParaView` is recommended).
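+
+If you prefer to stay in Python, a minimal `pyvista` sketch of the same heat map is shown below (the file path is hypothetical; `dist` is the scalar array that `distanceVertex2Mesh.py` attaches to the saved mesh):
+
+```python
+import pyvista as pv
+
+# Hypothetical output path produced by distanceVertex2Mesh.py.
+mesh = pv.read("output/CASE_NAME/CASE_NAME.vtk")
+
+# Color the prediction surface by its closest distance to the ground truth.
+mesh.plot(scalars="dist", cmap="coolwarm", show_scalar_bar=True)
+```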
+
+## Citing Paper
+
+If you find this paper helpful, please consider citing:
+```bibtex
+@article{amanian2024deep,
+ title={A Deep Learning Framework for Analysis of the Eustachian Tube and the Internal Carotid Artery},
+ author={Amanian, Ameen and Jain, Aseem and Xiao, Yuliang and Kim, Chanha and Ding, Andy S and Sahu, Manish and Taylor, Russell and Unberath, Mathias and Ward, Bryan K and Galaiya, Deepa and others},
+ journal={Otolaryngology--Head and Neck Surgery},
+ year={2024},
+ publisher={Wiley Online Library}
+}
+```
diff --git a/assets/method.png b/assets/method.png
new file mode 100644
index 0000000000000000000000000000000000000000..fba7cb1a08004593ba6c038d1533b917fa56083e
Binary files /dev/null and b/assets/method.png differ
diff --git a/metrics/distanceVertex2Mesh.py b/metrics/distanceVertex2Mesh.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1e09dd48117e414b1d1fe8c54fc25cd8c09f04c
--- /dev/null
+++ b/metrics/distanceVertex2Mesh.py
@@ -0,0 +1,64 @@
+import numpy as np
+import pyvista as pv
+import argparse
+import os
+import glob
+import trimesh
+
+
+def parse_command_line():
+ print('---'*10)
+ print('Parsing Command Line Arguments')
+    parser = argparse.ArgumentParser(description='Compute closest distances from predicted meshes to ground-truth meshes')
+ parser.add_argument('-bp', metavar='base path', type=str,
+ help="Absolute path of the base directory")
+ parser.add_argument('-gp', metavar='ground truth path', type=str,
+ help="Relative path of the ground truth model")
+ parser.add_argument('-pp', metavar='prediction path', type=str,
+ help="Relative path of the prediction model")
+ argv = parser.parse_args()
+ return argv
+
+
+def distanceVertex2Mesh(mesh, vertex):
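+    """Return, for each query vertex, the distance to the closest point on the surface of the mesh."""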
+ faces_as_array = mesh.faces.reshape((mesh.n_faces, 4))[:, 1:]
+ mesh_box = trimesh.Trimesh(vertices=mesh.points,
+ faces=faces_as_array)
+ cp, cd, ci = trimesh.proximity.closest_point(mesh_box, vertex)
+ return cd
+
+
+def main():
+ args = parse_command_line()
+ base = args.bp
+ gt_path = args.gp
+ pred_path = args.pp
+ output_dir = os.path.join(base, 'output')
+    try:
+        os.mkdir(output_dir)
+    except FileExistsError:
+        print(f'{output_dir} already exists')
+
+ for i in glob.glob(os.path.join(base, gt_path) + '/*.vtk'):
+ filename = os.path.basename(i).split('.')[0]
+ output_sub_dir = os.path.join(
+ base, 'output', filename)
+        try:
+            os.mkdir(output_sub_dir)
+        except FileExistsError:
+            print(f'{output_sub_dir} already exists')
+
+ gt_mesh = pv.read(i)
+ pred_mesh = pv.read(os.path.join(
+ base, pred_path, filename + '.vtk'))
+ pred_vertices = np.array(pred_mesh.points)
+ cd = distanceVertex2Mesh(gt_mesh, pred_vertices)
+ pred_mesh['dist'] = cd
+ pred_mesh.save(os.path.join(output_sub_dir, filename + '.vtk'))
+
+
+if __name__ == '__main__':
+ main()
diff --git a/metrics/get_probability_map.py b/metrics/get_probability_map.py
new file mode 100644
index 0000000000000000000000000000000000000000..569cf6b8f76b134a9cc0a67a1cbfd762b7cdc996
--- /dev/null
+++ b/metrics/get_probability_map.py
@@ -0,0 +1,194 @@
+import numpy as np
+import pyvista as pv
+import argparse
+import os
+import glob
+import skeletor as sk
+import trimesh
+import navis
+
+
+def parse_command_line():
+ print('---'*10)
+ print('Parsing Command Line Arguments')
+    parser = argparse.ArgumentParser(description='Generate a vertex-wise probability map from a skeletonized prediction mesh')
+ parser.add_argument('-bp', metavar='base path', type=str,
+ help="Absolute path of the base directory")
+ parser.add_argument('-gp', metavar='ground truth path', type=str,
+ help="Relative path of the ground truth model")
+ parser.add_argument('-pp', metavar='prediction path', type=str,
+ help="Relative path of the prediction model")
+ parser.add_argument('-rr', metavar='ratio to split skeleton', type=int, nargs='+',
+ help="Ratio to split the skeleton")
+ parser.add_argument('-ps', metavar='probability sequences', type=float, nargs='+',
+                        help="Probability sequences for each split region")
+ argv = parser.parse_args()
+ return argv
+
+
+def distanceVertex2Path(mesh, skeleton, probability_map):
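+    """Assign each mesh vertex the probability of its nearest skeleton node; returns np.inf on malformed input."""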
+ if len(probability_map) == 0:
+ print('empty probability_map !!!')
+ return np.inf
+
+ if not mesh.is_all_triangles():
+        print('only triangulated meshes are allowed (faces must have 3 vertices)!')
+ return np.inf
+
+ if hasattr(mesh, 'points'):
+ points = np.array(mesh.points)
+ else:
+        print("mesh structure must contain a 'points' field!")
+ return np.inf
+
+ if hasattr(skeleton, 'vertices'):
+ vertex = skeleton.vertices
+ else:
+        print("skeleton structure must contain a 'vertices' field!")
+ return np.inf
+
+ numV, dim = points.shape
+ numT, dimT = vertex.shape
+
+ if dim != dimT or dim != 3:
+ print('mesh and vertices must be in 3D space!')
+ return np.inf
+
+ d_min = np.ones(numV, dtype=np.float64) * np.inf
+ pm = []
+ # first check: find closest distance from vertex to vertex
+ for i in range(numV):
+ min_idx = -1
+ for j in range(numT):
+ v1 = points[i, :]
+ v2 = vertex[j, :]
+ d = distance3DV2V(v1, v2)
+ if d < d_min[i]:
+ d_min[i] = d
+ min_idx = j
+
+ pm.append(probability_map[min_idx])
+
+    print("closest-node assignment finished")
+ return pm
+
+
+def generate_probability_map(skeleton, split_ratio, probability):
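+    """Split skeleton nodes along x into left/right of the centroid and assign per-region probabilities."""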
+ points = skeleton.vertices
+ center = skeleton.skeleton.centroid
+ x = sorted(points[:, 0])
+ left = []
+ right = []
+ for i in range(len(x)):
+ if x[i] < center[0]:
+ left.append(x[i])
+ else:
+ right.append(x[i])
+
+ right_map = []
+ left_map = []
+ sec_old = 0
+ for j in range(len(split_ratio)):
+ if j == len(split_ratio) - 1:
+ sec_len = len(left) - sec_old
+ else:
+ sec_len = int(round(len(left) * split_ratio[j] / 100))
+
+ for k in range(sec_old, sec_old + sec_len):
+ left_map.append(probability[j])
+
+ sec_old += sec_len
+
+ sec_old = 0
+ for j in range(len(split_ratio)-1, -1, -1):
+ if j == 0:
+ sec_len = len(right) - sec_old
+ else:
+ sec_len = int(round(len(right) * split_ratio[j] / 100))
+
+ for k in range(sec_old, sec_old + sec_len):
+ right_map.append(probability[j])
+
+ sec_old += sec_len
+
+ final_map = []
+ row = points.shape[0]
+ assert len(left) + len(right) == row
+ for m in range(row):
+ ver_x = points[m, 0]
+ if ver_x in left:
+ index = left.index(ver_x)
+ final_map.append(left_map[index])
+ else:
+ index = right.index(ver_x)
+ final_map.append(right_map[index])
+
+ return final_map
+
+
+def skeleton(mesh):
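+    """Skeletonize a PyVista mesh and keep its two longest neurites as a skeletor Skeleton."""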
+ faces_as_array = mesh.faces.reshape((mesh.n_faces, 4))[:, 1:]
+ trmesh = trimesh.Trimesh(mesh.points, faces_as_array)
+ fixed = sk.pre.fix_mesh(trmesh, remove_disconnected=5, inplace=False)
+ skel = sk.skeletonize.by_wavefront(fixed, waves=1, step_size=1)
+ # Create a neuron from your skeleton
+ n = navis.TreeNeuron(skel, soma=None)
+    # keep only the two longest linear sections in your skeleton
+ long2 = navis.longest_neurite(n, n=2, from_root=False)
+
+ # This renumbers nodes
+ swc = navis.io.swc_io.make_swc_table(long2)
+ # We also need to rename some columns
+ swc = swc.rename({'PointNo': 'node_id', 'Parent': 'parent_id', 'X': 'x',
+ 'Y': 'y', 'Z': 'z', 'Radius': 'radius'}, axis=1).drop('Label', axis=1)
+    # Skeletor expects node IDs to start at 0, but navis starts at 1 for SWC
+ swc['node_id'] -= 1
+ swc.loc[swc.parent_id > 0, 'parent_id'] -= 1
+ # Create the skeletor.Skeleton
+ skel2 = sk.Skeleton(swc)
+ return skel2
+
+
+def distance3DV2V(v1, v2):
+ d = np.linalg.norm(v1-v2)
+ return d
+
+
+def main():
+ args = parse_command_line()
+ base = args.bp
+ gt_path = args.gp
+ pred_path = args.pp
+ area_ratio = args.rr
+ prob_sequences = args.ps
+ output_dir = os.path.join(base, 'output')
+    try:
+        os.mkdir(output_dir)
+    except FileExistsError:
+        print(f'{output_dir} already exists')
+
+ for i in glob.glob(os.path.join(base, gt_path) + '/*.vtk'):
+ scan_name = os.path.basename(i).split('.')[0].split('_')[1]
+ scan_id = os.path.basename(i).split('.')[0].split('_')[2]
+ output_sub_dir = os.path.join(
+ base, 'output', scan_name + '_' + scan_id)
+        try:
+            os.mkdir(output_sub_dir)
+        except FileExistsError:
+            print(f'{output_sub_dir} already exists')
+
+ gt_mesh = pv.read(i)
+ pred_mesh = pv.read(os.path.join(
+ base, pred_path, 'pred_' + scan_name + '_' + scan_id + '.vtk'))
+ pred_skel = skeleton(pred_mesh)
+ prob_map = generate_probability_map(
+ pred_skel, area_ratio, prob_sequences)
+ pm = distanceVertex2Path(pred_mesh, pred_skel, prob_map)
+        if pm == np.inf:
+            print('the mesh, probability map, or skeleton is malformed; aborting')
+            return
+        # output_sub_dir is already an absolute path under base
+        np.savetxt(os.path.join(output_sub_dir, scan_id + '.txt'), pm)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/metrics/lookup_tables.py b/metrics/lookup_tables.py
new file mode 100644
index 0000000000000000000000000000000000000000..30eff0c167d8018868c5b6cca206405f73ca2e2d
--- /dev/null
+++ b/metrics/lookup_tables.py
@@ -0,0 +1,463 @@
+# Copyright 2018 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS-IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import numpy as np
+ENCODE_NEIGHBOURHOOD_3D_KERNEL = np.array([[[128, 64], [32, 16]], [[8, 4],
+ [2, 1]]])
+
+"""
+lookup_tables.py
+
+All of the lookup-table functions are borrowed from DeepMind's surface_distance repository.
+"""
+
+
+# _NEIGHBOUR_CODE_TO_NORMALS is a lookup table.
+# For every binary neighbour code
+# (2x2x2 neighbourhood = 8 neighbours = 8 bits = 256 codes)
+# it contains the surface normals of the triangles (called "surfel" for
+# "surface element" in the following). The length of the normal
+# vector encodes the surfel area.
+#
+# created using the marching_cube algorithm
+# see e.g. https://en.wikipedia.org/wiki/Marching_cubes
+# pylint: disable=line-too-long
+_NEIGHBOUR_CODE_TO_NORMALS = [
+ [[0, 0, 0]],
+ [[0.125, 0.125, 0.125]],
+ [[-0.125, -0.125, 0.125]],
+ [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],
+ [[0.125, -0.125, 0.125]],
+ [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]],
+ [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
+ [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],
+ [[-0.125, 0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125]],
+ [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]],
+ [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],
+ [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],
+ [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125]],
+ [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],
+ [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0]],
+ [[0.125, -0.125, -0.125]],
+ [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25]],
+ [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
+ [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],
+ [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
+ [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125]],
+ [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
+ [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25],
+ [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],
+ [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125]],
+ [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125]],
+ [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125]],
+ [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375],
+ [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]],
+ [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],
+ [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25],
+ [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],
+ [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125],
+ [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]],
+ [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25]],
+ [[0.125, -0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],
+ [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25]],
+ [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]],
+ [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125]],
+ [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25],
+ [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]],
+ [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],
+ [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],
+ [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25],
+ [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]],
+ [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],
+ [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25],
+ [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]],
+ [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375],
+ [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]],
+ [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],
+ [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]],
+ [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],
+ [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25]],
+ [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0]],
+ [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125]],
+ [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25],
+ [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],
+ [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0],
+ [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],
+ [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],
+ [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]],
+ [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25],
+ [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
+ [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0],
+ [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],
+ [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],
+ [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0],
+ [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],
+ [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]],
+ [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]],
+ [[-0.125, -0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]],
+ [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
+ [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],
+ [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25]],
+ [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]],
+ [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125]],
+ [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25],
+ [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]],
+ [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],
+ [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]],
+ [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25],
+ [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
+ [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]],
+ [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25],
+ [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]],
+ [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375],
+ [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]],
+ [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]],
+ [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],
+ [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],
+ [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],
+ [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375],
+ [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]],
+ [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],
+ [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5]],
+ [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125],
+ [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],
+ [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],
+ [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125]],
+ [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25],
+ [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25],
+ [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],
+ [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],
+ [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375],
+ [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]],
+ [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],
+ [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],
+ [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],
+ [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125],
+ [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],
+ [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125]],
+ [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25],
+ [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25],
+ [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],
+ [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125]],
+ [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
+ [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125],
+ [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],
+ [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25],
+ [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
+ [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25],
+ [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
+ [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125]],
+ [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125]],
+ [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125],
+ [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],
+ [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25],
+ [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],
+ [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],
+ [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0],
+ [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]],
+ [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],
+ [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],
+ [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],
+ [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25],
+ [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],
+ [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],
+ [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],
+ [[0.125, 0.125, 0.125]],
+ [[0.125, 0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]],
+ [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0], [0.125, 0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25], [0.125, 0.125, 0.125]],
+ [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],
+ [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25],
+ [0.125, 0.125, 0.125], [0.125, 0.125, 0.125]],
+ [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],
+ [[0.125, 0.125, 0.125], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],
+ [[-0.125, 0.125, 0.125], [-0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],
+ [[-0.375, -0.375, 0.375], [0.25, -0.25, 0.0],
+ [0.0, 0.25, 0.25], [-0.125, -0.125, 0.125]],
+ [[0.0, -0.5, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],
+ [[0.375, -0.375, 0.375], [0.0, 0.25, 0.25],
+ [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],
+ [[-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125],
+ [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],
+ [[0.5, 0.0, -0.0], [0.25, -0.25, -0.25], [0.125, -0.125, -0.125]],
+ [[0.125, 0.125, 0.125], [0.125, -0.125, -0.125]],
+ [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.125, 0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
+ [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25],
+ [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
+ [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25],
+ [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
+ [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125],
+ [-0.125, -0.125, 0.125], [0.125, 0.125, 0.125]],
+ [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
+ [[0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.125, -0.125, -0.125]],
+ [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25], [0.0, 0.25, 0.25]],
+ [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25],
+ [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [0.125, -0.125, 0.125]],
+ [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125],
+ [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],
+ [[-0.125, -0.125, 0.125], [0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],
+ [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],
+ [[0.125, 0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],
+ [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]],
+ [[0.25, 0.0, 0.25], [-0.375, -0.375, 0.375],
+ [-0.25, 0.25, 0.0], [-0.125, -0.125, 0.125]],
+ [[0.125, -0.125, 0.125], [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],
+ [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25],
+ [0.25, 0.0, 0.25], [0.25, 0.0, 0.25]],
+ [[-0.0, 0.0, 0.5], [0.25, -0.25, 0.25],
+ [0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25], [-0.125, 0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],
+ [[0.125, 0.125, 0.125], [0.125, 0.125, 0.125],
+ [0.25, 0.25, 0.25], [0.0, 0.0, 0.5]],
+ [[-0.0, 0.0, 0.5], [0.0, 0.0, 0.5]],
+ [[0.0, 0.0, -0.5], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],
+ [[-0.25, -0.0, -0.25], [-0.375, 0.375, 0.375],
+ [-0.25, -0.25, 0.0], [-0.125, 0.125, 0.125]],
+ [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],
+ [[-0.0, 0.0, 0.5], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],
+ [[-0.25, 0.0, 0.25], [0.25, 0.0, -0.25]],
+ [[0.5, 0.0, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]],
+ [[-0.25, 0.0, -0.25], [0.375, -0.375, -0.375],
+ [0.0, 0.25, -0.25], [-0.125, 0.125, 0.125]],
+ [[-0.25, 0.25, -0.25], [-0.25, 0.25, -0.25],
+ [-0.125, 0.125, -0.125], [-0.125, 0.125, -0.125]],
+ [[-0.0, 0.5, 0.0], [-0.25, 0.25, -0.25], [0.125, -0.125, 0.125]],
+ [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25],
+ [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
+ [[-0.125, -0.125, 0.125], [-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]],
+ [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],
+ [[-0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],
+ [[0.375, -0.375, 0.375], [0.0, -0.25, -0.25],
+ [-0.125, 0.125, -0.125], [0.25, 0.25, 0.0]],
+ [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25], [-0.125, -0.125, 0.125]],
+ [[0.0, 0.0, 0.5], [0.25, -0.25, 0.25], [0.125, -0.125, 0.125]],
+ [[0.0, -0.25, 0.25], [0.0, -0.25, 0.25]],
+ [[-0.125, -0.125, 0.125], [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],
+ [[-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [-0.125, -0.125, 0.125]],
+ [[-0.125, -0.125, 0.125]],
+ [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]],
+ [[0.125, 0.125, 0.125], [-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0]],
+ [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0], [-0.125, -0.125, 0.125]],
+ [[-0.25, -0.25, 0.0], [-0.25, -0.25, 0.0],
+ [-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],
+ [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25], [-0.125, -0.125, 0.125]],
+ [[-0.375, 0.375, -0.375], [-0.25, -0.25, 0.0],
+ [-0.125, 0.125, -0.125], [-0.25, 0.0, 0.25]],
+ [[0.0, 0.5, 0.0], [0.25, 0.25, -0.25],
+ [-0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
+ [[-0.125, 0.125, 0.125], [0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]],
+ [[0.0, -0.5, 0.0], [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],
+ [[0.125, 0.125, 0.125], [0.0, -0.5, 0.0],
+ [-0.25, -0.25, -0.25], [-0.125, -0.125, -0.125]],
+ [[-0.375, -0.375, -0.375], [-0.25, 0.0, 0.25],
+ [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],
+ [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0], [0.125, -0.125, 0.125]],
+ [[0.0, 0.5, 0.0], [0.0, -0.5, 0.0]],
+ [[0.0, 0.5, 0.0], [0.125, -0.125, 0.125], [-0.25, 0.25, -0.25]],
+ [[0.0, 0.5, 0.0], [-0.25, 0.25, 0.25], [0.125, -0.125, -0.125]],
+ [[0.25, -0.25, 0.0], [-0.25, 0.25, 0.0]],
+ [[-0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],
+ [[0.0, 0.25, -0.25], [0.375, -0.375, -0.375],
+ [-0.125, 0.125, 0.125], [0.25, 0.25, 0.0]],
+ [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25],
+ [-0.125, 0.125, -0.125], [0.125, -0.125, 0.125]],
+ [[0.125, -0.125, 0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],
+ [[0.25, 0.25, -0.25], [0.25, 0.25, -0.25],
+ [0.125, 0.125, -0.125], [-0.125, -0.125, 0.125]],
+ [[-0.0, 0.0, 0.5], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],
+ [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125], [-0.125, 0.125, 0.125]],
+ [[-0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[-0.375, -0.375, 0.375], [-0.0, 0.25, 0.25],
+ [0.125, 0.125, -0.125], [-0.25, -0.0, -0.25]],
+ [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25], [0.125, -0.125, 0.125]],
+ [[0.125, -0.125, 0.125], [-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]],
+ [[0.125, -0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[0.0, -0.5, 0.0], [0.125, 0.125, -0.125], [0.25, 0.25, -0.25]],
+ [[0.0, -0.25, 0.25], [0.0, 0.25, -0.25]],
+ [[0.125, 0.125, 0.125], [0.125, -0.125, 0.125]],
+ [[0.125, -0.125, 0.125]],
+ [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25]],
+ [[-0.5, 0.0, 0.0], [-0.125, -0.125, -0.125],
+ [-0.25, -0.25, -0.25], [0.125, 0.125, 0.125]],
+ [[0.375, 0.375, 0.375], [0.0, 0.25, -0.25],
+ [-0.125, -0.125, -0.125], [-0.25, 0.25, 0.0]],
+ [[0.125, -0.125, -0.125], [0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],
+ [[0.125, 0.125, 0.125], [0.375, 0.375, 0.375],
+ [0.0, -0.25, 0.25], [-0.25, 0.0, 0.25]],
+ [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25], [0.125, -0.125, -0.125]],
+ [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25], [-0.125, 0.125, 0.125]],
+ [[-0.125, 0.125, 0.125], [0.125, -0.125, -0.125]],
+ [[-0.125, -0.125, -0.125], [-0.25, -0.25, -0.25],
+ [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],
+ [[-0.125, -0.125, 0.125], [0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
+ [[0.0, 0.0, -0.5], [0.25, 0.25, 0.25], [-0.125, -0.125, -0.125]],
+ [[0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
+ [[0.0, -0.5, 0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],
+ [[-0.125, -0.125, 0.125], [0.125, -0.125, -0.125]],
+ [[0.0, -0.25, -0.25], [0.0, 0.25, 0.25]],
+ [[0.125, -0.125, -0.125]],
+ [[0.5, 0.0, 0.0], [0.5, 0.0, 0.0]],
+ [[-0.5, 0.0, 0.0], [-0.25, 0.25, 0.25], [-0.125, 0.125, 0.125]],
+ [[0.5, 0.0, 0.0], [0.25, -0.25, 0.25], [-0.125, 0.125, -0.125]],
+ [[0.25, -0.25, 0.0], [0.25, -0.25, 0.0]],
+ [[0.5, 0.0, 0.0], [-0.25, -0.25, 0.25], [-0.125, -0.125, 0.125]],
+ [[-0.25, 0.0, 0.25], [-0.25, 0.0, 0.25]],
+ [[0.125, 0.125, 0.125], [-0.125, 0.125, 0.125]],
+ [[-0.125, 0.125, 0.125]],
+ [[0.5, 0.0, -0.0], [0.25, 0.25, 0.25], [0.125, 0.125, 0.125]],
+ [[0.125, -0.125, 0.125], [-0.125, -0.125, 0.125]],
+ [[-0.25, -0.0, -0.25], [0.25, 0.0, 0.25]],
+ [[0.125, -0.125, 0.125]],
+ [[-0.25, -0.25, 0.0], [0.25, 0.25, -0.0]],
+ [[-0.125, -0.125, 0.125]],
+ [[0.125, 0.125, 0.125]],
+ [[0, 0, 0]]]
+# pylint: enable=line-too-long
+
+
+def create_table_neighbour_code_to_surface_area(spacing_mm):
+ """Returns an array mapping neighbourhood code to the surface elements area.
+ Note that the normals encode the initial surface area. This function computes
+ the area corresponding to the given `spacing_mm`.
+ Args:
+ spacing_mm: 3-element list-like structure. Voxel spacing in x0, x1 and x2
+ direction.
+ """
+ # compute the area for all 256 possible surface elements
+ # (given a 2x2x2 neighbourhood) according to the spacing_mm
+ neighbour_code_to_surface_area = np.zeros([256])
+ for code in range(256):
+ normals = np.array(_NEIGHBOUR_CODE_TO_NORMALS[code])
+ sum_area = 0
+ for normal_idx in range(normals.shape[0]):
+ # normal vector
+ n = np.zeros([3])
+ n[0] = normals[normal_idx, 0] * spacing_mm[1] * spacing_mm[2]
+ n[1] = normals[normal_idx, 1] * spacing_mm[0] * spacing_mm[2]
+ n[2] = normals[normal_idx, 2] * spacing_mm[0] * spacing_mm[1]
+ area = np.linalg.norm(n)
+ sum_area += area
+ neighbour_code_to_surface_area[code] = sum_area
+
+ return neighbour_code_to_surface_area
+
+
+# In the neighbourhood, points are ordered: top left, top right, bottom left,
+# bottom right.
+ENCODE_NEIGHBOURHOOD_2D_KERNEL = np.array([[8, 4], [2, 1]])
+
+
+def create_table_neighbour_code_to_contour_length(spacing_mm):
+ """Returns an array mapping neighbourhood code to the contour length.
+ For the list of possible cases and their figures, see page 38 from:
+ https://nccastaff.bournemouth.ac.uk/jmacey/MastersProjects/MSc14/06/thesis.pdf
+  In 2D, each point has 4 neighbors. Thus, there are 16 configurations. A
+ configuration is encoded with '1' meaning "inside the object" and '0' "outside
+ the object". The points are ordered: top left, top right, bottom left, bottom
+ right.
+ The x0 axis is assumed vertical downward, and the x1 axis is horizontal to the
+ right:
+ (0, 0) --> (0, 1)
+ |
+ (1, 0)
+ Args:
+ spacing_mm: 2-element list-like structure. Voxel spacing in x0 and x1
+ directions.
+ """
+ neighbour_code_to_contour_length = np.zeros([16])
+
+ vertical = spacing_mm[0]
+ horizontal = spacing_mm[1]
+ diag = 0.5 * math.sqrt(spacing_mm[0]**2 + spacing_mm[1]**2)
+ # pyformat: disable
+ neighbour_code_to_contour_length[int("00"
+ "01", 2)] = diag
+
+ neighbour_code_to_contour_length[int("00"
+ "10", 2)] = diag
+
+ neighbour_code_to_contour_length[int("00"
+ "11", 2)] = horizontal
+
+ neighbour_code_to_contour_length[int("01"
+ "00", 2)] = diag
+
+ neighbour_code_to_contour_length[int("01"
+ "01", 2)] = vertical
+
+ neighbour_code_to_contour_length[int("01"
+ "10", 2)] = 2*diag
+
+ neighbour_code_to_contour_length[int("01"
+ "11", 2)] = diag
+
+ neighbour_code_to_contour_length[int("10"
+ "00", 2)] = diag
+
+ neighbour_code_to_contour_length[int("10"
+ "01", 2)] = 2*diag
+
+ neighbour_code_to_contour_length[int("10"
+ "10", 2)] = vertical
+
+ neighbour_code_to_contour_length[int("10"
+ "11", 2)] = diag
+
+ neighbour_code_to_contour_length[int("11"
+ "00", 2)] = horizontal
+
+ neighbour_code_to_contour_length[int("11"
+ "01", 2)] = diag
+
+ neighbour_code_to_contour_length[int("11"
+ "10", 2)] = diag
+ # pyformat: enable
+
+ return neighbour_code_to_contour_length
\ No newline at end of file
diff --git a/metrics/metrics.py b/metrics/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..047a733aa42dde8d0864a0330d3907d51cc03a04
--- /dev/null
+++ b/metrics/metrics.py
@@ -0,0 +1,355 @@
+import numpy as np
+import nibabel as nib
+import ants
+import argparse
+import pandas as pd
+import glob
+import os
+import surface_distance
+import nrrd
+import shutil
+import distanceVertex2Mesh
+import textwrap
+
+
+def parse_command_line():
+ print('---'*10)
+ print('Parsing Command Line Arguments')
+ parser = argparse.ArgumentParser(
+ description='Inference evaluation pipeline for image registration-segmentation', formatter_class=argparse.RawTextHelpFormatter)
+ parser.add_argument('-bp', metavar='base path', type=str,
+ help="Absolute path of the base directory")
+ parser.add_argument('-gp', metavar='ground truth path', type=str,
+ help="Relative path of the ground truth segmentation directory")
+ parser.add_argument('-pp', metavar='predicted path', type=str,
+ help="Relative path of predicted segmentation directory")
+ parser.add_argument('-sp', metavar='save path', type=str,
+                        help="Relative path of the directory to save the CSV file; if not specified, defaults to the base directory")
+ parser.add_argument('-vt', metavar='validation type', type=str, nargs='+',
+ help=textwrap.dedent('''Validation type:
+ dsc: Dice Score
+ ahd: Average Hausdorff Distance
+ whd: Weighted Hausdorff Distance
+ '''))
+ parser.add_argument('-pm', metavar='probability map path', type=str,
+ help="Relative path of text file directory of probability map")
+ parser.add_argument('-fn', metavar='file name', type=str,
+ help="name of output file")
+ parser.add_argument('-reg', action='store_true',
+ help="check if the input files are registration predictions")
+ parser.add_argument('-tp', metavar='type of segmentation', type=str,
+ help=textwrap.dedent('''Segmentation type:
+ ET: Eustachian Tube
+ NC: Nasal Cavity
+ HT: Head Tumor
+ '''))
+ parser.add_argument('-sl', metavar='segmentation information list', type=str, nargs='+',
+ help='a list of label name and corresponding value')
+ parser.add_argument('-cp', metavar='current prefix of filenames', type=str,
+ help='current prefix of filenames')
+ argv = parser.parse_args()
+ return argv
+
+
+def rename(prefix, filename):
+ name = filename.split('.')[0][-3:]
+ name = prefix + '_' + name
+ return name
+
+def dice_coefficient_and_hausdorff_distance(filename, img_np_pred, img_np_gt, num_classes, spacing, probability_map, dsc, ahd, whd, average_DSC, average_HD):
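+    """Compute per-label Dice and/or average Hausdorff distance for one case, updating the running totals."""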
+ df = pd.DataFrame()
+ data_gt, bool_gt = make_one_hot(img_np_gt, num_classes)
+ data_pred, bool_pred = make_one_hot(img_np_pred, num_classes)
+ for i in range(1, num_classes):
+ df1 = pd.DataFrame([[filename, i]], columns=[
+ 'File ID', 'Label Value'])
+        if dsc:
+            if data_pred[i].any():
+                volume_sum = data_gt[i].sum() + data_pred[i].sum()
+                if volume_sum == 0:
+                    # both masks empty: Dice is undefined for this label
+                    dice = np.nan
+                else:
+                    volume_intersect = (data_gt[i] & data_pred[i]).sum()
+                    dice = 2*volume_intersect / volume_sum
+            else:
+                dice = 0.0
+            df1['Dice Score'] = dice
+            average_DSC[i-1] += dice
+ if ahd:
+ if data_pred[i].any():
+ avd = average_hausdorff_distance(bool_gt[i], bool_pred[i], spacing)
+ df1['Average Hausdorff Distance'] = avd
+ average_HD[i-1] += avd
+ else:
+ avd = np.nan
+ df1['Average Hausdorff Distance'] = avd
+ average_HD[i-1] += avd
+ if whd:
+ # wgd = weighted_hausdorff_distance(gt, pred, probability_map)
+ # df1['Weighted Hausdorff Distance'] = wgd
+ pass
+
+ df = pd.concat([df, df1])
+ return df, average_DSC, average_HD
+
+
+def make_one_hot(img_np, num_classes):
+ img_one_hot_dice = np.zeros(
+ (num_classes, img_np.shape[0], img_np.shape[1], img_np.shape[2]), dtype=np.int8)
+ img_one_hot_hd = np.zeros(
+ (num_classes, img_np.shape[0], img_np.shape[1], img_np.shape[2]), dtype=bool)
+ for i in range(num_classes):
+ a = (img_np == i)
+ img_one_hot_dice[i, :, :, :] = a
+ img_one_hot_hd[i, :, :, :] = a
+
+ return img_one_hot_dice, img_one_hot_hd
+
+
+def average_hausdorff_distance(img_np_gt, img_np_pred, spacing):
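+    """Symmetric AHD: the mean of the two directed average surface distances."""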
+ surf_distance = surface_distance.compute_surface_distances(
+ img_np_gt, img_np_pred, spacing)
+ gp, pg = surface_distance.compute_average_surface_distance(surf_distance)
+ return (gp + pg) / 2
+
+
+def checkSegFormat(base, segmentation, type, prefix=None):
+ if type == 'gt':
+ save_dir = os.path.join(base, 'gt_reformat_labels')
+ path = segmentation
+ else:
+ save_dir = os.path.join(base, 'pred_reformat_labels')
+ path = os.path.join(base, segmentation)
+    try:
+        os.mkdir(save_dir)
+    except FileExistsError:
+        print(f'{save_dir} already exists')
+
+ for file in os.listdir(path):
+ if type == 'gt':
+ if prefix is not None:
+ name = rename(prefix, file)
+ else:
+ name = file.split('.')[0]
+ else:
+ name = file.split('.')[0]
+
+ if file.endswith('seg.nrrd'):
+ ants_img = ants.image_read(os.path.join(path, file))
+ header = nrrd.read_header(os.path.join(path, file))
+ filename = os.path.join(save_dir, name + '.nii.gz')
+ nrrd2nifti(ants_img, header, filename)
+ elif file.endswith('nii'):
+ image = ants.image_read(os.path.join(path, file))
+ image.to_file(os.path.join(save_dir, name + '.nii.gz'))
+ elif file.endswith('nii.gz'):
+ shutil.copy(os.path.join(path, file), os.path.join(save_dir, name + '.nii.gz'))
+
+ return save_dir
+
+
+def nrrd2nifti(img, header, filename):
+ img_as_np = img.view(single_components=True)
+ data = convert_to_one_hot(img_as_np, header)
+ foreground = np.max(data, axis=0)
+ labelmap = np.multiply(np.argmax(data, axis=0) + 1,
+ foreground).astype('uint8')
+ segmentation_img = ants.from_numpy(
+ labelmap, origin=img.origin, spacing=img.spacing, direction=img.direction)
+ print('-- Saving NII Segmentations')
+ segmentation_img.to_file(filename)
+
+
+def convert_to_one_hot(data, header, segment_indices=None):
+ print('---'*10)
+ print("converting to one hot")
+
+ layer_values = get_layer_values(header)
+ label_values = get_label_values(header)
+
+ # Newer Slicer NRRD (compressed layers)
+ if layer_values and label_values:
+
+ assert len(layer_values) == len(label_values)
+ if len(data.shape) == 3:
+ x_dim, y_dim, z_dim = data.shape
+ elif len(data.shape) == 4:
+ x_dim, y_dim, z_dim = data.shape[1:]
+
+ num_segments = len(layer_values)
+ one_hot = np.zeros((num_segments, x_dim, y_dim, z_dim))
+
+ if segment_indices is None:
+ segment_indices = list(range(num_segments))
+
+ elif isinstance(segment_indices, int):
+ segment_indices = [segment_indices]
+
+ elif not isinstance(segment_indices, list):
+ print("incorrectly specified segment indices")
+ return
+
+ # Check if NRRD is composed of one layer 0
+ if np.max(layer_values) == 0:
+ for i, seg_idx in enumerate(segment_indices):
+ label = label_values[seg_idx]
+ one_hot[i] = 1*(data == label).astype(np.uint8)
+
+ else:
+ for i, seg_idx in enumerate(segment_indices):
+ layer = layer_values[seg_idx]
+ label = label_values[seg_idx]
+ one_hot[i] = 1*(data[layer] == label).astype(np.uint8)
+
+ # Binary labelmap
+ elif len(data.shape) == 3:
+ x_dim, y_dim, z_dim = data.shape
+ num_segments = np.max(data)
+ one_hot = np.zeros((num_segments, x_dim, y_dim, z_dim))
+
+ if segment_indices is None:
+ segment_indices = list(range(1, num_segments + 1))
+
+ elif isinstance(segment_indices, int):
+ segment_indices = [segment_indices]
+
+ elif not isinstance(segment_indices, list):
+ print("incorrectly specified segment indices")
+ return
+
+ for i, seg_idx in enumerate(segment_indices):
+ one_hot[i] = 1*(data == seg_idx).astype(np.uint8)
+
+ # Older Slicer NRRD (already one-hot)
+ else:
+ return data
+
+ return one_hot
+
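+# Illustrative sketch (binary-labelmap branch): for a plain 3D labelmap with
+# values {0, 1, 2} and no Segment*_Layer/LabelValue metadata in the header,
+# the result has shape (2, x, y, z) with one_hot[0] == (data == 1) and
+# one_hot[1] == (data == 2).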
+
+def get_layer_values(header):
+ layer_values = []
+ num_segments = len([key for key in header.keys() if "Layer" in key])
+ for i in range(num_segments):
+ layer_values.append(int(header['Segment{}_Layer'.format(i)]))
+ return layer_values
+
+
+def get_label_values(header):
+ label_values = []
+ num_segments = len([key for key in header.keys() if "LabelValue" in key])
+ for i in range(num_segments):
+ label_values.append(int(header['Segment{}_LabelValue'.format(i)]))
+ return label_values
+
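+# Illustrative sketch (assumed header contents): for a 3D Slicer header with
+#   {'Segment0_LabelValue': '1', 'Segment1_LabelValue': '2'}
+# get_label_values(header) returns [1, 2]; get_layer_values works the same
+# way on the Segment*_Layer keys.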
+
+def main():
+ args = parse_command_line()
+ base = args.bp
+ gt_path = args.gp
+ pred_path = args.pp
+ if args.sp is None:
+ save_path = base
+ else:
+ save_path = args.sp
+ validation_type = args.vt
+ probability_map_path = args.pm
+ filename = args.fn
+ reg = args.reg
+ seg_type = args.tp
+ label_list = args.sl
+ current_prefix = args.cp
+ if probability_map_path is not None:
+ probability_map = np.loadtxt(os.path.join(base, probability_map_path))
+ else:
+ probability_map = None
+ dsc = False
+ ahd = False
+ whd = False
+ for vt in validation_type:
+ if vt == 'dsc':
+ dsc = True
+ elif vt == 'ahd':
+ ahd = True
+ elif vt == 'whd':
+ whd = True
+ else:
+ print(f'Unknown validation type {vt!r}; valid options are dsc, ahd and whd.')
+ return
+
+ filepath = os.path.join(base, save_path, 'output_' + filename + '.csv')
+ save_dir = os.path.join(base, save_path)
+ gt_output_path = checkSegFormat(base, gt_path, 'gt', current_prefix)
+ pred_output_path = checkSegFormat(base, pred_path, 'pred', current_prefix)
+ try:
+ os.mkdir(save_dir)
+ except FileExistsError:
+ print(f'{save_dir} already exists')
+
+ try:
+ os.mknod(filepath)
+ except FileExistsError:
+ print(f'{filepath} already exists')
+
+ DSC = pd.DataFrame()
+ # checkSegFormat already returns a path rooted at `base`
+ file = glob.glob(os.path.join(gt_output_path, '*nii.gz'))[0]
+ seg_file = ants.image_read(file)
+ num_class = np.unique(seg_file.numpy().ravel()).shape[0]
+ average_DSC = np.zeros((num_class-1))
+ average_HD = np.zeros((num_class-1))
+ k = 0
+ for i in glob.glob(os.path.join(pred_output_path, '*nii.gz')):
+ k += 1
+ pred_img = ants.image_read(i)
+ pred_spacing = list(pred_img.spacing)
+ base_name = os.path.basename(i).split('.')[0]
+ parts = base_name.split('_')
+ if reg and seg_type == 'ET':
+ file_name = '_'.join(parts[4:7])
+ elif reg and seg_type == 'NC':
+ file_name = '_'.join(parts[3:5])
+ elif reg and seg_type == 'HT':
+ file_name = parts[2]
+ else:
+ file_name = base_name
+ file_name1 = base_name
+ gt_seg = os.path.join(gt_output_path, file_name + '.nii.gz')
+ gt_img = ants.image_read(gt_seg)
+ gt_spacing = list(gt_img.spacing)
+
+ if gt_spacing != pred_spacing:
+ print("Spacing of prediction and ground truth do not match; please check again.")
+ return
+
+ # the prediction volume is passed as the first data argument and the
+ # ground truth as the second; Dice and the averaged Hausdorff distance
+ # used here are both symmetric, so the order does not change the results
+ data_pred_np = pred_img.numpy()
+ data_gt_np = gt_img.numpy()
+
+ # the number of classes is derived from the ground-truth labelmap
+ num_class = len(np.unique(data_gt_np))
+ ds, aver_DSC, aver_HD = dice_coefficient_and_hausdorff_distance(
+ file_name1, data_pred_np, data_gt_np, num_class, pred_spacing,
+ probability_map, dsc, ahd, whd, average_DSC, average_HD)
+ DSC = pd.concat([DSC, ds])
+ average_DSC = aver_DSC
+ average_HD = aver_HD
+
+ if k == 0:
+ print('No prediction files found; nothing to evaluate.')
+ return
+ avg_DSC = average_DSC / k
+ avg_HD = average_HD / k
+ print(avg_DSC)
+ with open(os.path.join(base, save_path, "metric.txt"), 'w') as f:
+ f.write("Label Value Label Name Average Dice Score Average Mean HD\n")
+ for i in range(len(avg_DSC)):
+ f.write(f'{str(i+1):^12}{str(label_list[2*i+1]):^12}{str(avg_DSC[i]):^20}{str(avg_HD[i]):^18}\n')
+ DSC.to_csv(filepath)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/metrics/surface_distance.py b/metrics/surface_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd44a0b6e199a919773bf2672fdc0378d97f2aa6
--- /dev/null
+++ b/metrics/surface_distance.py
@@ -0,0 +1,424 @@
+# Copyright 2018 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import lookup_tables # pylint: disable=relative-beyond-top-level
+import numpy as np
+from scipy import ndimage
+
+"""
+
+surface_distance.py
+
+all of the surface_distance functions are borrowed from DeepMind surface_distance repository
+
+"""
+def _assert_is_numpy_array(name, array):
+ """Raises an exception if `array` is not a numpy array."""
+ if not isinstance(array, np.ndarray):
+ raise ValueError("The argument {!r} should be a numpy array, not a "
+ "{}".format(name, type(array)))
+
+
+def _check_nd_numpy_array(name, array, num_dims):
+ """Raises an exception if `array` is not a `num_dims`-D numpy array."""
+ if len(array.shape) != num_dims:
+ raise ValueError("The argument {!r} should be a {}D array, not of "
+ "shape {}".format(name, num_dims, array.shape))
+
+
+def _check_2d_numpy_array(name, array):
+ _check_nd_numpy_array(name, array, num_dims=2)
+
+
+def _check_3d_numpy_array(name, array):
+ _check_nd_numpy_array(name, array, num_dims=3)
+
+
+def _assert_is_bool_numpy_array(name, array):
+ _assert_is_numpy_array(name, array)
+ if array.dtype != np.bool_:
+ raise ValueError("The argument {!r} should be a numpy array of type bool, "
+ "not {}".format(name, array.dtype))
+
+
+def _compute_bounding_box(mask):
+ """Computes the bounding box of the masks.
+ This function generalizes to an arbitrary number of dimensions greater
+ than or equal to 1.
+ Args:
+ mask: The 2D or 3D numpy mask, where '0' means background and non-zero means
+ foreground.
+ Returns:
+ A tuple:
+ - The coordinates of the first point of the bounding box (smallest on all
+ axes), or `None` if the mask contains only zeros.
+ - The coordinates of the second point of the bounding box (greatest on all
+ axes), or `None` if the mask contains only zeros.
+ """
+ num_dims = len(mask.shape)
+ bbox_min = np.zeros(num_dims, np.int64)
+ bbox_max = np.zeros(num_dims, np.int64)
+
+ # max projection to the x0-axis
+ proj_0 = np.amax(mask, axis=tuple(range(num_dims))[1:])
+ idx_nonzero_0 = np.nonzero(proj_0)[0]
+ if len(idx_nonzero_0) == 0: # pylint: disable=g-explicit-length-test
+ return None, None
+
+ bbox_min[0] = np.min(idx_nonzero_0)
+ bbox_max[0] = np.max(idx_nonzero_0)
+
+ # max projection to the i-th-axis for i in {1, ..., num_dims - 1}
+ for axis in range(1, num_dims):
+ max_over_axes = list(range(num_dims)) # Python 3 compatible
+ max_over_axes.pop(axis) # Remove the i-th dimension from the max
+ max_over_axes = tuple(max_over_axes) # numpy expects a tuple of ints
+ proj = np.amax(mask, axis=max_over_axes)
+ idx_nonzero = np.nonzero(proj)[0]
+ bbox_min[axis] = np.min(idx_nonzero)
+ bbox_max[axis] = np.max(idx_nonzero)
+
+ return bbox_min, bbox_max
+
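+# Illustrative sketch: for mask = np.zeros((4, 4)); mask[1:3, 2] = 1,
+# _compute_bounding_box(mask) returns (array([1, 2]), array([2, 2])), i.e.
+# the smallest and largest corner indices of the nonzero region.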
+
+def _crop_to_bounding_box(mask, bbox_min, bbox_max):
+ """Crops a 2D or 3D mask to the bounding box specified by `bbox_{min,max}`."""
+ # we need to zeropad the cropped region with 1 voxel at the lower,
+ # the right (and the back on 3D) sides. This is required to obtain the
+ # "full" convolution result with the 2x2 (or 2x2x2 in 3D) kernel.
+ # TODO: This is correct only if the object is interior to the
+ # bounding box.
+ cropmask = np.zeros((bbox_max - bbox_min) + 2, np.uint8)
+
+ num_dims = len(mask.shape)
+ # pyformat: disable
+ if num_dims == 2:
+ cropmask[0:-1, 0:-1] = mask[bbox_min[0]:bbox_max[0] + 1,
+ bbox_min[1]:bbox_max[1] + 1]
+ elif num_dims == 3:
+ cropmask[0:-1, 0:-1, 0:-1] = mask[bbox_min[0]:bbox_max[0] + 1,
+ bbox_min[1]:bbox_max[1] + 1,
+ bbox_min[2]:bbox_max[2] + 1]
+ # pyformat: enable
+ else:
+ assert False
+
+ return cropmask
+
+
+def _sort_distances_surfels(distances, surfel_areas):
+ """Sorts the two list with respect to the tuple of (distance, surfel_area).
+ Args:
+ distances: The distances from A to B (e.g. `distances_gt_to_pred`).
+ surfel_areas: The surfel areas for A (e.g. `surfel_areas_gt`).
+ Returns:
+ A tuple of the sorted (distances, surfel_areas).
+ """
+ sorted_surfels = np.array(sorted(zip(distances, surfel_areas)))
+ return sorted_surfels[:, 0], sorted_surfels[:, 1]
+
+
+def compute_surface_distances(mask_gt,
+ mask_pred,
+ spacing_mm):
+ """Computes closest distances from all surface points to the other surface.
+ This function can be applied to 2D or 3D tensors. For 2D, both masks must be
+ 2D and `spacing_mm` must be a 2-element list. For 3D, both masks must be 3D
+ and `spacing_mm` must be a 3-element list. The description is done for the 2D
+ case, and the formulation for the 3D case is given in parentheses,
+ introduced by "resp.".
+ Finds all contour elements (resp surface elements "surfels" in 3D) in the
+ ground truth mask `mask_gt` and the predicted mask `mask_pred`, computes their
+ length in mm (resp. area in mm^2) and the distance to the closest point on the
+ other contour (resp. surface). It returns two sorted lists of distances
+ together with the corresponding contour lengths (resp. surfel areas). If one
+ of the masks is empty, the corresponding lists are empty and all distances in
+ the other list are `inf`.
+ Args:
+ mask_gt: 2-dim (resp. 3-dim) bool Numpy array. The ground truth mask.
+ mask_pred: 2-dim (resp. 3-dim) bool Numpy array. The predicted mask.
+ spacing_mm: 2-element (resp. 3-element) list-like structure. Voxel spacing
+ in the x0 and x1 (resp. x0, x1 and x2) directions.
+ Returns:
+ A dict with:
+ "distances_gt_to_pred": 1-dim numpy array of type float. The distances in mm
+ from all ground truth surface elements to the predicted surface,
+ sorted from smallest to largest.
+ "distances_pred_to_gt": 1-dim numpy array of type float. The distances in mm
+ from all predicted surface elements to the ground truth surface,
+ sorted from smallest to largest.
+ "surfel_areas_gt": 1-dim numpy array of type float. The length of the
+ of the ground truth contours in mm (resp. the surface elements area in
+ mm^2) in the same order as distances_gt_to_pred.
+ "surfel_areas_pred": 1-dim numpy array of type float. The length of the
+ of the predicted contours in mm (resp. the surface elements area in
+ mm^2) in the same order as distances_gt_to_pred.
+ Raises:
+ ValueError: If the masks and the `spacing_mm` arguments are of incompatible
+ shape or type. Or if the masks are not 2D or 3D.
+ """
+ # The terms used in this function are for the 3D case. In particular, surface
+ # in 2D stands for contours in 3D. The surface elements in 3D correspond to
+ # the line elements in 2D.
+
+ _assert_is_bool_numpy_array("mask_gt", mask_gt)
+ _assert_is_bool_numpy_array("mask_pred", mask_pred)
+
+ if not len(mask_gt.shape) == len(mask_pred.shape) == len(spacing_mm):
+ raise ValueError("The arguments must be of compatible shape. Got mask_gt "
+ "with {} dimensions ({}) and mask_pred with {} dimensions "
+ "({}), while the spacing_mm was {} elements.".format(
+ len(mask_gt.shape),
+ mask_gt.shape, len(
+ mask_pred.shape), mask_pred.shape,
+ len(spacing_mm)))
+
+ num_dims = len(spacing_mm)
+ if num_dims == 2:
+ _check_2d_numpy_array("mask_gt", mask_gt)
+ _check_2d_numpy_array("mask_pred", mask_pred)
+
+ # compute the area for all 16 possible surface elements
+ # (given a 2x2 neighbourhood) according to the spacing_mm
+ neighbour_code_to_surface_area = (
+ lookup_tables.create_table_neighbour_code_to_contour_length(spacing_mm))
+ kernel = lookup_tables.ENCODE_NEIGHBOURHOOD_2D_KERNEL
+ full_true_neighbours = 0b1111
+ elif num_dims == 3:
+ _check_3d_numpy_array("mask_gt", mask_gt)
+ _check_3d_numpy_array("mask_pred", mask_pred)
+
+ # compute the area for all 256 possible surface elements
+ # (given a 2x2x2 neighbourhood) according to the spacing_mm
+ neighbour_code_to_surface_area = (
+ lookup_tables.create_table_neighbour_code_to_surface_area(spacing_mm))
+ kernel = lookup_tables.ENCODE_NEIGHBOURHOOD_3D_KERNEL
+ full_true_neighbours = 0b11111111
+ else:
+ raise ValueError("Only 2D and 3D masks are supported, not "
+ "{}D.".format(num_dims))
+
+ # compute the bounding box of the masks to trim the volume to the smallest
+ # possible processing subvolume
+ bbox_min, bbox_max = _compute_bounding_box(mask_gt | mask_pred)
+ # Both the min/max bbox are None at the same time, so we only check one.
+ if bbox_min is None:
+ return {
+ "distances_gt_to_pred": np.array([]),
+ "distances_pred_to_gt": np.array([]),
+ "surfel_areas_gt": np.array([]),
+ "surfel_areas_pred": np.array([]),
+ }
+
+ # crop the processing subvolume.
+ cropmask_gt = _crop_to_bounding_box(mask_gt, bbox_min, bbox_max)
+ cropmask_pred = _crop_to_bounding_box(mask_pred, bbox_min, bbox_max)
+
+ # compute the neighbour code (local binary pattern) for each voxel
+ # the resulting arrays are spatially shifted by minus half a voxel in each
+ # axis.
+ # i.e. the points are located at the corners of the original voxels
+ neighbour_code_map_gt = ndimage.correlate(
+ cropmask_gt.astype(np.uint8), kernel, mode="constant", cval=0)
+ neighbour_code_map_pred = ndimage.correlate(
+ cropmask_pred.astype(np.uint8), kernel, mode="constant", cval=0)
+
+ # create masks with the surface voxels
+ borders_gt = ((neighbour_code_map_gt != 0) &
+ (neighbour_code_map_gt != full_true_neighbours))
+ borders_pred = ((neighbour_code_map_pred != 0) &
+ (neighbour_code_map_pred != full_true_neighbours))
+
+ # compute the distance transform (closest distance of each voxel to the
+ # surface voxels)
+ if borders_gt.any():
+ distmap_gt = ndimage.distance_transform_edt(
+ ~borders_gt, sampling=spacing_mm)
+ else:
+ distmap_gt = np.inf * np.ones(borders_gt.shape)
+
+ if borders_pred.any():
+ distmap_pred = ndimage.distance_transform_edt(
+ ~borders_pred, sampling=spacing_mm)
+ else:
+ distmap_pred = np.inf * np.ones(borders_pred.shape)
+
+ # compute the area of each surface element
+ surface_area_map_gt = neighbour_code_to_surface_area[neighbour_code_map_gt]
+ surface_area_map_pred = neighbour_code_to_surface_area[
+ neighbour_code_map_pred]
+
+ # create a list of all surface elements with distance and area
+ distances_gt_to_pred = distmap_pred[borders_gt]
+ distances_pred_to_gt = distmap_gt[borders_pred]
+ surfel_areas_gt = surface_area_map_gt[borders_gt]
+ surfel_areas_pred = surface_area_map_pred[borders_pred]
+
+ # sort them by distance
+ if distances_gt_to_pred.shape != (0,):
+ distances_gt_to_pred, surfel_areas_gt = _sort_distances_surfels(
+ distances_gt_to_pred, surfel_areas_gt)
+
+ if distances_pred_to_gt.shape != (0,):
+ distances_pred_to_gt, surfel_areas_pred = _sort_distances_surfels(
+ distances_pred_to_gt, surfel_areas_pred)
+
+ return {
+ "distances_gt_to_pred": distances_gt_to_pred,
+ "distances_pred_to_gt": distances_pred_to_gt,
+ "surfel_areas_gt": surfel_areas_gt,
+ "surfel_areas_pred": surfel_areas_pred,
+ }
+
+
+def compute_average_surface_distance(surface_distances):
+ """Returns the average surface distance.
+ Computes the average surface distances by correctly taking the area of each
+ surface element into account. Call compute_surface_distances(...) before, to
+ obtain the `surface_distances` dict.
+ Args:
+ surface_distances: dict with "distances_gt_to_pred", "distances_pred_to_gt"
+ "surfel_areas_gt", "surfel_areas_pred" created by
+ compute_surface_distances()
+ Returns:
+ A tuple with two float values:
+ - the average distance (in mm) from the ground truth surface to the
+ predicted surface
+ - the average distance from the predicted surface to the ground truth
+ surface.
+ """
+ distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
+ distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
+ surfel_areas_gt = surface_distances["surfel_areas_gt"]
+ surfel_areas_pred = surface_distances["surfel_areas_pred"]
+ average_distance_gt_to_pred = (
+ np.sum(distances_gt_to_pred * surfel_areas_gt) / np.sum(surfel_areas_gt))
+ average_distance_pred_to_gt = (
+ np.sum(distances_pred_to_gt * surfel_areas_pred) /
+ np.sum(surfel_areas_pred))
+ return (average_distance_gt_to_pred, average_distance_pred_to_gt)
+
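+# Illustrative sketch (assumed toy masks): two 6x6x6 cubes offset by one
+# voxel along the first axis, at 1 mm isotropic spacing:
+#   mask_gt = np.zeros((10, 10, 10), dtype=bool); mask_gt[2:8, 2:8, 2:8] = True
+#   mask_pred = np.zeros_like(mask_gt); mask_pred[3:9, 2:8, 2:8] = True
+#   sd = compute_surface_distances(mask_gt, mask_pred, spacing_mm=(1.0, 1.0, 1.0))
+#   compute_average_surface_distance(sd)  # -> (avg gt->pred, avg pred->gt), each < 1 mm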
+
+def compute_robust_hausdorff(surface_distances, percent):
+ """Computes the robust Hausdorff distance.
+ Computes the robust Hausdorff distance. "Robust", because it uses the
+ `percent` percentile of the distances instead of the maximum distance. The
+ percentage is computed by correctly taking the area of each surface element
+ into account.
+ Args:
+ surface_distances: dict with "distances_gt_to_pred", "distances_pred_to_gt"
+ "surfel_areas_gt", "surfel_areas_pred" created by
+ compute_surface_distances()
+ percent: a float value between 0 and 100.
+ Returns:
+ a float value. The robust Hausdorff distance in mm.
+ """
+ distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
+ distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
+ surfel_areas_gt = surface_distances["surfel_areas_gt"]
+ surfel_areas_pred = surface_distances["surfel_areas_pred"]
+ if len(distances_gt_to_pred) > 0: # pylint: disable=g-explicit-length-test
+ surfel_areas_cum_gt = np.cumsum(
+ surfel_areas_gt) / np.sum(surfel_areas_gt)
+ idx = np.searchsorted(surfel_areas_cum_gt, percent/100.0)
+ perc_distance_gt_to_pred = distances_gt_to_pred[
+ min(idx, len(distances_gt_to_pred)-1)]
+ else:
+ perc_distance_gt_to_pred = np.inf
+
+ if len(distances_pred_to_gt) > 0: # pylint: disable=g-explicit-length-test
+ surfel_areas_cum_pred = (np.cumsum(surfel_areas_pred) /
+ np.sum(surfel_areas_pred))
+ idx = np.searchsorted(surfel_areas_cum_pred, percent/100.0)
+ perc_distance_pred_to_gt = distances_pred_to_gt[
+ min(idx, len(distances_pred_to_gt)-1)]
+ else:
+ perc_distance_pred_to_gt = np.inf
+
+ return max(perc_distance_gt_to_pred, perc_distance_pred_to_gt)
+
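+# Illustrative sketch: with `sd` computed by compute_surface_distances(...),
+#   compute_robust_hausdorff(sd, 95)   # 95th-percentile ("HD95") distance in mm
+#   compute_robust_hausdorff(sd, 100)  # the classic (maximum) Hausdorff distance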
+
+def compute_surface_overlap_at_tolerance(surface_distances, tolerance_mm):
+ """Computes the overlap of the surfaces at a specified tolerance.
+ Computes the overlap of the ground truth surface with the predicted surface
+ and vice versa allowing a specified tolerance (maximum surface-to-surface
+ distance that is regarded as overlapping). The overlapping fraction is
+ computed by correctly taking the area of each surface element into account.
+ Args:
+ surface_distances: dict with "distances_gt_to_pred", "distances_pred_to_gt"
+ "surfel_areas_gt", "surfel_areas_pred" created by
+ compute_surface_distances()
+ tolerance_mm: a float value. The tolerance in mm
+ Returns:
+ A tuple of two float values. The overlap fraction in [0.0, 1.0] of the
+ ground truth surface with the predicted surface and vice versa.
+ """
+ distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
+ distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
+ surfel_areas_gt = surface_distances["surfel_areas_gt"]
+ surfel_areas_pred = surface_distances["surfel_areas_pred"]
+ rel_overlap_gt = (
+ np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm]) /
+ np.sum(surfel_areas_gt))
+ rel_overlap_pred = (
+ np.sum(surfel_areas_pred[distances_pred_to_gt <= tolerance_mm]) /
+ np.sum(surfel_areas_pred))
+ return (rel_overlap_gt, rel_overlap_pred)
+
+
+def compute_surface_dice_at_tolerance(surface_distances, tolerance_mm):
+ """Computes the _surface_ DICE coefficient at a specified tolerance.
+ Computes the _surface_ DICE coefficient at a specified tolerance. Not to be
+ confused with the standard _volumetric_ DICE coefficient. The surface DICE
+ measures the overlap of two surfaces instead of two volumes. A surface
+ element is counted as overlapping (or touching), when the closest distance to
+ the other surface is less or equal to the specified tolerance. The DICE
+ coefficient is in the range between 0.0 (no overlap) to 1.0 (perfect overlap).
+ Args:
+ surface_distances: dict with "distances_gt_to_pred", "distances_pred_to_gt"
+ "surfel_areas_gt", "surfel_areas_pred" created by
+ compute_surface_distances()
+ tolerance_mm: a float value. The tolerance in mm
+ Returns:
+ A float value. The surface DICE coefficient in [0.0, 1.0].
+ """
+ distances_gt_to_pred = surface_distances["distances_gt_to_pred"]
+ distances_pred_to_gt = surface_distances["distances_pred_to_gt"]
+ surfel_areas_gt = surface_distances["surfel_areas_gt"]
+ surfel_areas_pred = surface_distances["surfel_areas_pred"]
+ overlap_gt = np.sum(surfel_areas_gt[distances_gt_to_pred <= tolerance_mm])
+ overlap_pred = np.sum(
+ surfel_areas_pred[distances_pred_to_gt <= tolerance_mm])
+ surface_dice = (overlap_gt + overlap_pred) / (
+ np.sum(surfel_areas_gt) + np.sum(surfel_areas_pred))
+ return surface_dice
+
+
+def compute_dice_coefficient(mask_gt, mask_pred):
+ """Computes soerensen-dice coefficient.
+ compute the soerensen-dice coefficient between the ground truth mask `mask_gt`
+ and the predicted mask `mask_pred`.
+ Args:
+ mask_gt: 3-dim Numpy array of type bool. The ground truth mask.
+ mask_pred: 3-dim Numpy array of type bool. The predicted mask.
+ Returns:
+ the dice coeffcient as float. If both masks are empty, the result is NaN.
+ """
+ volume_sum = mask_gt.sum() + mask_pred.sum()
+ if volume_sum == 0:
+ return np.NaN
+ volume_intersect = (mask_gt & mask_pred).sum()
+ return 2*volume_intersect / volume_sum
\ No newline at end of file
diff --git a/nnunet/__init__.py b/nnunet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..28e5288ed7a4ff4f0cfc9366cafb088b6bf0c6b1
--- /dev/null
+++ b/nnunet/__init__.py
@@ -0,0 +1,7 @@
+from __future__ import absolute_import
+print("\n\nPlease cite the following paper when using nnUNet:\n\nIsensee, F., Jaeger, P.F., Kohl, S.A.A. et al. "
+ "\"nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation.\" "
+ "Nat Methods (2020). https://doi.org/10.1038/s41592-020-01008-z\n\n")
+print("If you have questions or suggestions, feel free to open an issue at https://github.com/MIC-DKFZ/nnUNet\n")
+
+from . import *
\ No newline at end of file
diff --git a/nnunet/configuration.py b/nnunet/configuration.py
new file mode 100644
index 0000000000000000000000000000000000000000..19ed4d68f2f96003f73468c52606e61ad2f88cac
--- /dev/null
+++ b/nnunet/configuration.py
@@ -0,0 +1,5 @@
+import os
+
+default_num_threads = 8 if 'nnUNet_def_n_proc' not in os.environ else int(os.environ['nnUNet_def_n_proc'])
+RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD = 3 # spacing-anisotropy threshold above which the
+# low-resolution axis is resampled separately (with nearest-neighbour interpolation)
\ No newline at end of file
diff --git a/nnunet/dataset_conversion/Task017_BeyondCranialVaultAbdominalOrganSegmentation.py b/nnunet/dataset_conversion/Task017_BeyondCranialVaultAbdominalOrganSegmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5cf5ee87ce9ad73588ee4b826baa59c3f1cbed8
--- /dev/null
+++ b/nnunet/dataset_conversion/Task017_BeyondCranialVaultAbdominalOrganSegmentation.py
@@ -0,0 +1,94 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from collections import OrderedDict
+from nnunet.paths import nnUNet_raw_data
+from batchgenerators.utilities.file_and_folder_operations import *
+import shutil
+
+
+if __name__ == "__main__":
+ base = "/media/yunlu/10TB/research/other_data/Multi-Atlas Labeling Beyond the Cranial Vault/RawData/"
+
+ task_id = 17
+ task_name = "AbdominalOrganSegmentation"
+ prefix = 'ABD'
+
+ foldername = "Task%03.0d_%s" % (task_id, task_name)
+
+ out_base = join(nnUNet_raw_data, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ train_folder = join(base, "Training/img")
+ label_folder = join(base, "Training/label")
+ test_folder = join(base, "Test/img")
+ train_patient_names = []
+ test_patient_names = []
+ train_patients = subfiles(train_folder, join=False, suffix='nii.gz')
+ for p in train_patients:
+ serial_number = int(p[3:7])
+ train_patient_name = f'{prefix}_{serial_number:03d}.nii.gz'
+ label_file = join(label_folder, f'label{p[3:]}')
+ image_file = join(train_folder, p)
+ shutil.copy(image_file, join(imagestr, f'{train_patient_name[:7]}_0000.nii.gz'))
+ shutil.copy(label_file, join(labelstr, train_patient_name))
+ train_patient_names.append(train_patient_name)
+
+ test_patients = subfiles(test_folder, join=False, suffix=".nii.gz")
+ for p in test_patients:
+ p = p[:-7]
+ image_file = join(test_folder, p + ".nii.gz")
+ serial_number = int(p[3:7])
+ test_patient_name = f'{prefix}_{serial_number:03d}.nii.gz'
+ shutil.copy(image_file, join(imagests, f'{test_patient_name[:7]}_0000.nii.gz'))
+ test_patient_names.append(test_patient_name)
+
+ json_dict = OrderedDict()
+ json_dict['name'] = "AbdominalOrganSegmentation"
+ json_dict['description'] = "Multi-Atlas Labeling Beyond the Cranial Vault Abdominal Organ Segmentation"
+ json_dict['tensorImageSize'] = "3D"
+ json_dict['reference'] = "https://www.synapse.org/#!Synapse:syn3193805/wiki/217789"
+ json_dict['licence'] = "see challenge website"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "CT",
+ }
+ json_dict['labels'] = OrderedDict({
+ "00": "background",
+ "01": "spleen",
+ "02": "right kidney",
+ "03": "left kidney",
+ "04": "gallbladder",
+ "05": "esophagus",
+ "06": "liver",
+ "07": "stomach",
+ "08": "aorta",
+ "09": "inferior vena cava",
+ "10": "portal vein and splenic vein",
+ "11": "pancreas",
+ "12": "right adrenal gland",
+ "13": "left adrenal gland"}
+ )
+ json_dict['numTraining'] = len(train_patient_names)
+ json_dict['numTest'] = len(test_patient_names)
+ json_dict['training'] = [{'image': "./imagesTr/%s" % train_patient_name, "label": "./labelsTr/%s" % train_patient_name} for i, train_patient_name in enumerate(train_patient_names)]
+ json_dict['test'] = ["./imagesTs/%s" % test_patient_name for test_patient_name in test_patient_names]
+
+ save_json(json_dict, os.path.join(out_base, "dataset.json"))
diff --git a/nnunet/dataset_conversion/Task024_Promise2012.py b/nnunet/dataset_conversion/Task024_Promise2012.py
new file mode 100644
index 0000000000000000000000000000000000000000..e090fa16eef4b2cbb2d1bb7c7324441f8472e77c
--- /dev/null
+++ b/nnunet/dataset_conversion/Task024_Promise2012.py
@@ -0,0 +1,81 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from collections import OrderedDict
+import SimpleITK as sitk
+from batchgenerators.utilities.file_and_folder_operations import *
+
+
+def export_for_submission(source_dir, target_dir):
+ """
+ promise wants mhd :-/
+ :param source_dir:
+ :param target_dir:
+ :return:
+ """
+ files = subfiles(source_dir, suffix=".nii.gz", join=False)
+ target_files = [join(target_dir, i[:-7] + ".mhd") for i in files]
+ maybe_mkdir_p(target_dir)
+ for f, t in zip(files, target_files):
+ img = sitk.ReadImage(join(source_dir, f))
+ sitk.WriteImage(img, t)
+
+
+if __name__ == "__main__":
+ folder = "/media/fabian/My Book/datasets/promise2012"
+ out_folder = "/media/fabian/My Book/MedicalDecathlon/MedicalDecathlon_raw_splitted/Task024_Promise"
+
+ maybe_mkdir_p(join(out_folder, "imagesTr"))
+ maybe_mkdir_p(join(out_folder, "imagesTs"))
+ maybe_mkdir_p(join(out_folder, "labelsTr"))
+ # train
+ current_dir = join(folder, "train")
+ segmentations = subfiles(current_dir, suffix="segmentation.mhd")
+ raw_data = [i for i in subfiles(current_dir, suffix="mhd") if not i.endswith("segmentation.mhd")]
+ for i in raw_data:
+ out_fname = join(out_folder, "imagesTr", i.split("/")[-1][:-4] + "_0000.nii.gz")
+ sitk.WriteImage(sitk.ReadImage(i), out_fname)
+ for i in segmentations:
+ out_fname = join(out_folder, "labelsTr", i.split("/")[-1][:-17] + ".nii.gz")
+ sitk.WriteImage(sitk.ReadImage(i), out_fname)
+
+ # test
+ current_dir = join(folder, "test")
+ test_data = subfiles(current_dir, suffix="mhd")
+ for i in test_data:
+ out_fname = join(out_folder, "imagesTs", i.split("/")[-1][:-4] + "_0000.nii.gz")
+ sitk.WriteImage(sitk.ReadImage(i), out_fname)
+
+
+ json_dict = OrderedDict()
+ json_dict['name'] = "PROMISE12"
+ json_dict['description'] = "prostate"
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "see challenge website"
+ json_dict['licence'] = "see challenge website"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "MRI",
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "prostate"
+ }
+ json_dict['numTraining'] = len(raw_data)
+ json_dict['numTest'] = len(test_data)
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i.split("/")[-1][:-4], "label": "./labelsTr/%s.nii.gz" % i.split("/")[-1][:-4]} for i in
+ raw_data]
+ json_dict['test'] = ["./imagesTs/%s.nii.gz" % i.split("/")[-1][:-4] for i in test_data]
+
+ save_json(json_dict, os.path.join(out_folder, "dataset.json"))
+
diff --git a/nnunet/dataset_conversion/Task027_AutomaticCardiacDetectionChallenge.py b/nnunet/dataset_conversion/Task027_AutomaticCardiacDetectionChallenge.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dface40a494faef3c43a4e9b72f80d9d2f8ba45
--- /dev/null
+++ b/nnunet/dataset_conversion/Task027_AutomaticCardiacDetectionChallenge.py
@@ -0,0 +1,106 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict
+from batchgenerators.utilities.file_and_folder_operations import *
+import shutil
+import numpy as np
+from sklearn.model_selection import KFold
+
+
+def convert_to_submission(source_dir, target_dir):
+ niftis = subfiles(source_dir, join=False, suffix=".nii.gz")
+ patientids = np.unique([i[:10] for i in niftis])
+ maybe_mkdir_p(target_dir)
+ for p in patientids:
+ files_of_that_patient = subfiles(source_dir, prefix=p, suffix=".nii.gz", join=False)
+ assert len(files_of_that_patient)
+ files_of_that_patient.sort()
+ # first is ED, second is ES
+ shutil.copy(join(source_dir, files_of_that_patient[0]), join(target_dir, p + "_ED.nii.gz"))
+ shutil.copy(join(source_dir, files_of_that_patient[1]), join(target_dir, p + "_ES.nii.gz"))
+
+
+if __name__ == "__main__":
+ folder = "/media/fabian/My Book/datasets/ACDC/training"
+ folder_test = "/media/fabian/My Book/datasets/ACDC/testing/testing"
+ out_folder = "/media/fabian/My Book/MedicalDecathlon/MedicalDecathlon_raw_splitted/Task027_ACDC"
+
+ maybe_mkdir_p(join(out_folder, "imagesTr"))
+ maybe_mkdir_p(join(out_folder, "imagesTs"))
+ maybe_mkdir_p(join(out_folder, "labelsTr"))
+
+ # train
+ all_train_files = []
+ patient_dirs_train = subfolders(folder, prefix="patient")
+ for p in patient_dirs_train:
+ current_dir = p
+ data_files_train = [i for i in subfiles(current_dir, suffix=".nii.gz") if i.find("_gt") == -1 and i.find("_4d") == -1]
+ corresponding_seg_files = [i[:-7] + "_gt.nii.gz" for i in data_files_train]
+ for d, s in zip(data_files_train, corresponding_seg_files):
+ patient_identifier = d.split("/")[-1][:-7]
+ all_train_files.append(patient_identifier + "_0000.nii.gz")
+ shutil.copy(d, join(out_folder, "imagesTr", patient_identifier + "_0000.nii.gz"))
+ shutil.copy(s, join(out_folder, "labelsTr", patient_identifier + ".nii.gz"))
+
+ # test
+ all_test_files = []
+ patient_dirs_test = subfolders(folder_test, prefix="patient")
+ for p in patient_dirs_test:
+ current_dir = p
+ data_files_test = [i for i in subfiles(current_dir, suffix=".nii.gz") if i.find("_gt") == -1 and i.find("_4d") == -1]
+ for d in data_files_test:
+ patient_identifier = d.split("/")[-1][:-7]
+ all_test_files.append(patient_identifier + "_0000.nii.gz")
+ shutil.copy(d, join(out_folder, "imagesTs", patient_identifier + "_0000.nii.gz"))
+
+
+ json_dict = OrderedDict()
+ json_dict['name'] = "ACDC"
+ json_dict['description'] = "cardias cine MRI segmentation"
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "see ACDC challenge"
+ json_dict['licence'] = "see ACDC challenge"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "MRI",
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "RV",
+ "2": "MLV",
+ "3": "LVC"
+ }
+ json_dict['numTraining'] = len(all_train_files)
+ json_dict['numTest'] = len(all_test_files)
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i.split("/")[-1][:-12], "label": "./labelsTr/%s.nii.gz" % i.split("/")[-1][:-12]} for i in
+ all_train_files]
+ json_dict['test'] = ["./imagesTs/%s.nii.gz" % i.split("/")[-1][:-12] for i in all_test_files]
+
+ save_json(json_dict, os.path.join(out_folder, "dataset.json"))
+
+ # create a dummy split (patients need to be separated)
+ splits = []
+ patients = np.unique([i[:10] for i in all_train_files])
+ patientids = [i[:-12] for i in all_train_files]
+
+ kf = KFold(n_splits=5, shuffle=True, random_state=12345)
+ for tr, val in kf.split(patients):
+ splits.append(OrderedDict())
+ tr_patients = patients[tr]
+ splits[-1]['train'] = [i[:-12] for i in all_train_files if i[:10] in tr_patients]
+ val_patients = patients[val]
+ splits[-1]['val'] = [i[:-12] for i in all_train_files if i[:10] in val_patients]
+
+ save_pickle(splits, "/media/fabian/nnunet/Task027_ACDC/splits_final.pkl")
\ No newline at end of file
diff --git a/nnunet/dataset_conversion/Task029_LiverTumorSegmentationChallenge.py b/nnunet/dataset_conversion/Task029_LiverTumorSegmentationChallenge.py
new file mode 100644
index 0000000000000000000000000000000000000000..11bcdd1f9a731caac758f689988f1f6fadfa13d9
--- /dev/null
+++ b/nnunet/dataset_conversion/Task029_LiverTumorSegmentationChallenge.py
@@ -0,0 +1,123 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict
+import SimpleITK as sitk
+from batchgenerators.utilities.file_and_folder_operations import *
+from multiprocessing import Pool
+import numpy as np
+from nnunet.configuration import default_num_threads
+from scipy.ndimage import label
+
+
+def export_segmentations(indir, outdir):
+ niftis = subfiles(indir, suffix='nii.gz', join=False)
+ for n in niftis:
+ identifier = str(n.split("_")[-1][:-7])
+ outfname = join(outdir, "test-segmentation-%s.nii" % identifier)
+ img = sitk.ReadImage(join(indir, n))
+ sitk.WriteImage(img, outfname)
+
+
+def export_segmentations_postprocess(indir, outdir):
+ maybe_mkdir_p(outdir)
+ niftis = subfiles(indir, suffix='nii.gz', join=False)
+ for n in niftis:
+ print("\n", n)
+ identifier = str(n.split("_")[-1][:-7])
+ outfname = join(outdir, "test-segmentation-%s.nii" % identifier)
+ img = sitk.ReadImage(join(indir, n))
+ img_npy = sitk.GetArrayFromImage(img)
+ lmap, num_objects = label((img_npy > 0).astype(int))
+ sizes = []
+ for o in range(1, num_objects + 1):
+ sizes.append((lmap == o).sum())
+ mx = np.argmax(sizes) + 1
+ print(sizes)
+ img_npy[lmap != mx] = 0
+ img_new = sitk.GetImageFromArray(img_npy)
+ img_new.CopyInformation(img)
+ sitk.WriteImage(img_new, outfname)
+
+
+if __name__ == "__main__":
+ train_dir = "/media/fabian/DeepLearningData/tmp/LITS-Challenge-Train-Data"
+ test_dir = "/media/fabian/My Book/datasets/LiTS/test_data"
+
+
+ output_folder = "/media/fabian/My Book/MedicalDecathlon/MedicalDecathlon_raw_splitted/Task029_LITS"
+ img_dir = join(output_folder, "imagesTr")
+ lab_dir = join(output_folder, "labelsTr")
+ img_dir_te = join(output_folder, "imagesTs")
+ maybe_mkdir_p(img_dir)
+ maybe_mkdir_p(lab_dir)
+ maybe_mkdir_p(img_dir_te)
+
+
+ def load_save_train(args):
+ data_file, seg_file = args
+ pat_id = data_file.split("/")[-1]
+ pat_id = "train_" + pat_id.split("-")[-1][:-4]
+
+ img_itk = sitk.ReadImage(data_file)
+ sitk.WriteImage(img_itk, join(img_dir, pat_id + "_0000.nii.gz"))
+
+ img_itk = sitk.ReadImage(seg_file)
+ sitk.WriteImage(img_itk, join(lab_dir, pat_id + ".nii.gz"))
+ return pat_id
+
+ def load_save_test(args):
+ data_file = args
+ pat_id = data_file.split("/")[-1]
+ pat_id = "test_" + pat_id.split("-")[-1][:-4]
+
+ img_itk = sitk.ReadImage(data_file)
+ sitk.WriteImage(img_itk, join(img_dir_te, pat_id + "_0000.nii.gz"))
+ return pat_id
+
+ nii_files_tr_data = subfiles(train_dir, True, "volume", "nii", True)
+ nii_files_tr_seg = subfiles(train_dir, True, "segmen", "nii", True)
+
+ nii_files_ts = subfiles(test_dir, True, "test-volume", "nii", True)
+
+ p = Pool(default_num_threads)
+ train_ids = p.map(load_save_train, zip(nii_files_tr_data, nii_files_tr_seg))
+ test_ids = p.map(load_save_test, nii_files_ts)
+ p.close()
+ p.join()
+
+ json_dict = OrderedDict()
+ json_dict['name'] = "LITS"
+ json_dict['description'] = "LITS"
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "see challenge website"
+ json_dict['licence'] = "see challenge website"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "CT"
+ }
+
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "liver",
+ "2": "tumor"
+ }
+
+ json_dict['numTraining'] = len(train_ids)
+ json_dict['numTest'] = len(test_ids)
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i, "label": "./labelsTr/%s.nii.gz" % i} for i in train_ids]
+ json_dict['test'] = ["./imagesTs/%s.nii.gz" % i for i in test_ids]
+
+ with open(os.path.join(output_folder, "dataset.json"), 'w') as f:
+ json.dump(json_dict, f, indent=4, sort_keys=True)
\ No newline at end of file
diff --git a/nnunet/dataset_conversion/Task032_BraTS_2018.py b/nnunet/dataset_conversion/Task032_BraTS_2018.py
new file mode 100644
index 0000000000000000000000000000000000000000..1401ada9faa20376f7ef6aa06aebdded0b1ac481
--- /dev/null
+++ b/nnunet/dataset_conversion/Task032_BraTS_2018.py
@@ -0,0 +1,176 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from multiprocessing.pool import Pool
+
+import numpy as np
+from collections import OrderedDict
+
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.dataset_conversion.Task043_BraTS_2019 import copy_BraTS_segmentation_and_convert_labels
+from nnunet.paths import nnUNet_raw_data
+import SimpleITK as sitk
+import shutil
+
+
+def convert_labels_back_to_BraTS(seg: np.ndarray):
+ new_seg = np.zeros_like(seg)
+ new_seg[seg == 1] = 2
+ new_seg[seg == 3] = 4
+ new_seg[seg == 2] = 1
+ return new_seg
+
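+# Illustrative sketch: nnU-Net's internal labels map back to the BraTS
+# convention as 1 -> 2 (edema), 2 -> 1 (non-enhancing), 3 -> 4 (enhancing),
+# so np.array([0, 1, 2, 3]) becomes [0, 2, 1, 4].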
+
+def load_convert_save(filename, input_folder, output_folder):
+ a = sitk.ReadImage(join(input_folder, filename))
+ b = sitk.GetArrayFromImage(a)
+ c = convert_labels_back_to_BraTS(b)
+ d = sitk.GetImageFromArray(c)
+ d.CopyInformation(a)
+ sitk.WriteImage(d, join(output_folder, filename))
+
+
+def convert_labels_back_to_BraTS_2018_2019_convention(input_folder: str, output_folder: str, num_processes: int = 12):
+ """
+ reads all prediction files (nifti) in the input folder, converts the labels back to BraTS convention and saves the
+ result in output_folder
+ :param input_folder:
+ :param output_folder:
+ :return:
+ """
+ maybe_mkdir_p(output_folder)
+ nii = subfiles(input_folder, suffix='.nii.gz', join=False)
+ p = Pool(num_processes)
+ p.starmap(load_convert_save, zip(nii, [input_folder] * len(nii), [output_folder] * len(nii)))
+ p.close()
+ p.join()
+
+
+if __name__ == "__main__":
+ """
+ REMEMBER TO CONVERT LABELS BACK TO BRATS CONVENTION AFTER PREDICTION!
+ """
+
+ task_name = "Task032_BraTS2018"
+ downloaded_data_dir = "/home/fabian/Downloads/BraTS2018_train_val_test_data/MICCAI_BraTS_2018_Data_Training"
+
+ target_base = join(nnUNet_raw_data, task_name)
+ target_imagesTr = join(target_base, "imagesTr")
+ target_imagesVal = join(target_base, "imagesVal")
+ target_imagesTs = join(target_base, "imagesTs")
+ target_labelsTr = join(target_base, "labelsTr")
+
+ maybe_mkdir_p(target_imagesTr)
+ maybe_mkdir_p(target_imagesVal)
+ maybe_mkdir_p(target_imagesTs)
+ maybe_mkdir_p(target_labelsTr)
+
+ patient_names = []
+ for tpe in ["HGG", "LGG"]:
+ cur = join(downloaded_data_dir, tpe)
+ for p in subdirs(cur, join=False):
+ patdir = join(cur, p)
+ patient_name = tpe + "__" + p
+ patient_names.append(patient_name)
+ t1 = join(patdir, p + "_t1.nii.gz")
+ t1c = join(patdir, p + "_t1ce.nii.gz")
+ t2 = join(patdir, p + "_t2.nii.gz")
+ flair = join(patdir, p + "_flair.nii.gz")
+ seg = join(patdir, p + "_seg.nii.gz")
+
+ assert all([
+ isfile(t1),
+ isfile(t1c),
+ isfile(t2),
+ isfile(flair),
+ isfile(seg)
+ ]), "%s" % patient_name
+
+ shutil.copy(t1, join(target_imagesTr, patient_name + "_0000.nii.gz"))
+ shutil.copy(t1c, join(target_imagesTr, patient_name + "_0001.nii.gz"))
+ shutil.copy(t2, join(target_imagesTr, patient_name + "_0002.nii.gz"))
+ shutil.copy(flair, join(target_imagesTr, patient_name + "_0003.nii.gz"))
+
+ copy_BraTS_segmentation_and_convert_labels(seg, join(target_labelsTr, patient_name + ".nii.gz"))
+
+ json_dict = OrderedDict()
+ json_dict['name'] = "BraTS2018"
+ json_dict['description'] = "nothing"
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "see BraTS2018"
+ json_dict['licence'] = "see BraTS2019 license"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "T1",
+ "1": "T1ce",
+ "2": "T2",
+ "3": "FLAIR"
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "edema",
+ "2": "non-enhancing",
+ "3": "enhancing",
+ }
+ json_dict['numTraining'] = len(patient_names)
+ json_dict['numTest'] = 0
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i, "label": "./labelsTr/%s.nii.gz" % i} for i in
+ patient_names]
+ json_dict['test'] = []
+
+ save_json(json_dict, join(target_base, "dataset.json"))
+
+ del tpe, cur
+ downloaded_data_dir = "/home/fabian/Downloads/BraTS2018_train_val_test_data/MICCAI_BraTS_2018_Data_Validation"
+
+ for p in subdirs(downloaded_data_dir, join=False):
+ patdir = join(downloaded_data_dir, p)
+ patient_name = p
+ t1 = join(patdir, p + "_t1.nii.gz")
+ t1c = join(patdir, p + "_t1ce.nii.gz")
+ t2 = join(patdir, p + "_t2.nii.gz")
+ flair = join(patdir, p + "_flair.nii.gz")
+
+ assert all([
+ isfile(t1),
+ isfile(t1c),
+ isfile(t2),
+ isfile(flair),
+ ]), "%s" % patient_name
+
+ shutil.copy(t1, join(target_imagesVal, patient_name + "_0000.nii.gz"))
+ shutil.copy(t1c, join(target_imagesVal, patient_name + "_0001.nii.gz"))
+ shutil.copy(t2, join(target_imagesVal, patient_name + "_0002.nii.gz"))
+ shutil.copy(flair, join(target_imagesVal, patient_name + "_0003.nii.gz"))
+
+ downloaded_data_dir = "/home/fabian/Downloads/BraTS2018_train_val_test_data/MICCAI_BraTS_2018_Data_Testing_FIsensee"
+
+ for p in subdirs(downloaded_data_dir, join=False):
+ patdir = join(downloaded_data_dir, p)
+ patient_name = p
+ t1 = join(patdir, p + "_t1.nii.gz")
+ t1c = join(patdir, p + "_t1ce.nii.gz")
+ t2 = join(patdir, p + "_t2.nii.gz")
+ flair = join(patdir, p + "_flair.nii.gz")
+
+ assert all([
+ isfile(t1),
+ isfile(t1c),
+ isfile(t2),
+ isfile(flair),
+ ]), "%s" % patient_name
+
+ shutil.copy(t1, join(target_imagesTs, patient_name + "_0000.nii.gz"))
+ shutil.copy(t1c, join(target_imagesTs, patient_name + "_0001.nii.gz"))
+ shutil.copy(t2, join(target_imagesTs, patient_name + "_0002.nii.gz"))
+ shutil.copy(flair, join(target_imagesTs, patient_name + "_0003.nii.gz"))
diff --git a/nnunet/dataset_conversion/Task035_ISBI_MSLesionSegmentationChallenge.py b/nnunet/dataset_conversion/Task035_ISBI_MSLesionSegmentationChallenge.py
new file mode 100644
index 0000000000000000000000000000000000000000..a71b2e91c7120f9d3ef9df1055e69053254bf142
--- /dev/null
+++ b/nnunet/dataset_conversion/Task035_ISBI_MSLesionSegmentationChallenge.py
@@ -0,0 +1,162 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+from collections import OrderedDict
+import numpy as np
+import SimpleITK as sitk
+import multiprocessing
+from batchgenerators.utilities.file_and_folder_operations import *
+
+
+def convert_to_nii_gz(filename):
+ f = sitk.ReadImage(filename)
+ sitk.WriteImage(f, os.path.splitext(filename)[0] + ".nii.gz")
+ os.remove(filename)
+
+
+def convert_for_submission(source_dir, target_dir):
+ files = subfiles(source_dir, suffix=".nii.gz", join=False)
+ maybe_mkdir_p(target_dir)
+ for f in files:
+ splitted = f.split("__")
+ case_id = int(splitted[1])
+ timestep = int(splitted[2][:-7])
+ t = join(target_dir, "test%02d_%02d_nnUNet.nii" % (case_id, timestep))
+ img = sitk.ReadImage(join(source_dir, f))
+ sitk.WriteImage(img, t)
+
+
+if __name__ == "__main__":
+ # convert to nifti.gz
+ dirs = ['/media/fabian/My Book/MedicalDecathlon/Task035_ISBILesionSegmentation/imagesTr',
+ '/media/fabian/My Book/MedicalDecathlon/Task035_ISBILesionSegmentation/imagesTs',
+ '/media/fabian/My Book/MedicalDecathlon/Task035_ISBILesionSegmentation/labelsTr']
+
+ p = multiprocessing.Pool(3)
+
+ for d in dirs:
+ nii_files = subfiles(d, suffix='.nii')
+ p.map(convert_to_nii_gz, nii_files)
+
+ p.close()
+ p.join()
+
+
+ def rename_files(folder):
+ all_files = subfiles(folder, join=False)
+ # there are max 14 patients per folder, starting with 1
+ for patientid in range(1, 15):
+ # there are certainly no more than 10 time steps per patient, starting with 1
+ for t in range(1, 10):
+ patient_files = [i for i in all_files if i.find("%02.0d_%02.0d_" % (patientid, t)) != -1]
+ if not len(patient_files) == 4:
+ continue
+
+ flair_file = [i for i in patient_files if i.endswith("_flair_pp.nii.gz")][0]
+ mprage_file = [i for i in patient_files if i.endswith("_mprage_pp.nii.gz")][0]
+ pd_file = [i for i in patient_files if i.endswith("_pd_pp.nii.gz")][0]
+ t2_file = [i for i in patient_files if i.endswith("_t2_pp.nii.gz")][0]
+
+ os.rename(join(folder, flair_file), join(folder, "case__%02.0d__%02.0d_0000.nii.gz" % (patientid, t)))
+ os.rename(join(folder, mprage_file), join(folder, "case__%02.0d__%02.0d_0001.nii.gz" % (patientid, t)))
+ os.rename(join(folder, pd_file), join(folder, "case__%02.0d__%02.0d_0002.nii.gz" % (patientid, t)))
+ os.rename(join(folder, t2_file), join(folder, "case__%02.0d__%02.0d_0003.nii.gz" % (patientid, t)))
+
+
+ for d in dirs[:-1]:
+ rename_files(d)
+
+
+ # now we have to deal with the training masks, we do it the quick and dirty way here by just creating copies of the
+ # training data
+
+ train_folder = '/media/fabian/My Book/MedicalDecathlon/Task035_ISBILesionSegmentation/imagesTr'
+
+ for patientid in range(1, 6):
+ for t in range(1, 6):
+ fnames_original = subfiles(train_folder, prefix="case__%02.0d__%02.0d" % (patientid, t), suffix=".nii.gz", sort=True)
+ for f in fnames_original:
+ for mask in [1, 2]:
+ fname_target = f[:-12] + "__mask%d" % mask + f[-12:]
+ shutil.copy(f, fname_target)
+ os.remove(f)
+
+
+ labels_folder = '/media/fabian/My Book/MedicalDecathlon/Task035_ISBILesionSegmentation/labelsTr'
+
+ for patientid in range(1, 6):
+ for t in range(1, 6):
+ for mask in [1, 2]:
+ f = join(labels_folder, "training%02d_%02d_mask%d.nii.gz" % (patientid, t, mask))
+ if isfile(f):
+ os.rename(f, join(labels_folder, "case__%02.0d__%02.0d__mask%d.nii.gz" % (patientid, t, mask)))
+
+
+
+ tr_files = []
+ for patientid in range(1, 6):
+ for t in range(1, 6):
+ for mask in [1, 2]:
+ if isfile(join(labels_folder, "case__%02.0d__%02.0d__mask%d.nii.gz" % (patientid, t, mask))):
+ tr_files.append("case__%02.0d__%02.0d__mask%d.nii.gz" % (patientid, t, mask))
+
+
+ ts_files = []
+ for patientid in range(1, 20):
+ for t in range(1, 20):
+ if isfile(join("/media/fabian/My Book/MedicalDecathlon/Task035_ISBILesionSegmentation/imagesTs",
+ "case__%02.0d__%02.0d_0000.nii.gz" % (patientid, t))):
+ ts_files.append("case__%02.0d__%02.0d.nii.gz" % (patientid, t))
+
+
+ out_base = '/media/fabian/My Book/MedicalDecathlon/Task035_ISBILesionSegmentation/'
+
+ json_dict = OrderedDict()
+ json_dict['name'] = "ISBI_Lesion_Segmentation_Challenge_2015"
+ json_dict['description'] = "nothing"
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "see challenge website"
+ json_dict['licence'] = "see challenge website"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "flair",
+ "1": "mprage",
+ "2": "pd",
+ "3": "t2"
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "lesion"
+ }
+ json_dict['numTraining'] = len(subfiles(labels_folder))
+ json_dict['numTest'] = len(subfiles('/media/fabian/My Book/MedicalDecathlon/Task035_ISBILesionSegmentation/imagesTs')) // 4
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i[:-7], "label": "./labelsTr/%s.nii.gz" % i[:-7]} for i in
+ tr_files]
+ json_dict['test'] = ["./imagesTs/%s.nii.gz" % i[:-7] for i in ts_files]
+
+ save_json(json_dict, join(out_base, "dataset.json"))
+
+ case_identifiers = np.unique([i[:-12] for i in subfiles("/media/fabian/My Book/MedicalDecathlon/MedicalDecathlon_raw_splitted/Task035_ISBILesionSegmentation/imagesTr", suffix='.nii.gz', join=False)])
+
+ splits = []
+ for f in range(5):
+ splits.append(OrderedDict())
+ splits[-1]['val'] = np.array([i for i in case_identifiers if i.startswith("case__%02d__" % (f + 1))])
+ remaining = [i for i in case_identifiers if i not in splits[-1]['val']]
+ splits[-1]['train'] = np.array(remaining)
+
+ maybe_mkdir_p("/media/fabian/nnunet/Task035_ISBILesionSegmentation")
+ save_pickle(splits, join("/media/fabian/nnunet/Task035_ISBILesionSegmentation", "splits_final.pkl"))
diff --git a/nnunet/dataset_conversion/Task037_038_Chaos_Challenge.py b/nnunet/dataset_conversion/Task037_038_Chaos_Challenge.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fa3dcc85dfa2671cf9e21bac133fe8ef9b76e60
--- /dev/null
+++ b/nnunet/dataset_conversion/Task037_038_Chaos_Challenge.py
@@ -0,0 +1,460 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from PIL import Image
+import shutil
+from collections import OrderedDict
+
+import dicom2nifti
+import numpy as np
+from batchgenerators.utilities.data_splitting import get_split_deterministic
+from batchgenerators.utilities.file_and_folder_operations import *
+import SimpleITK as sitk
+from nnunet.paths import preprocessing_output_dir, nnUNet_raw_data
+from nnunet.utilities.sitk_stuff import copy_geometry
+from nnunet.inference.ensemble_predictions import merge
+
+
+def load_png_stack(folder):
+ pngs = subfiles(folder, suffix="png")
+ pngs.sort()
+ loaded = []
+ for p in pngs:
+ loaded.append(np.array(Image.open(p)))
+ loaded = np.stack(loaded, 0)[::-1]
+ return loaded
+
+
+def convert_CT_seg(loaded_png):
+ return loaded_png.astype(np.uint16)
+
+
+def convert_MR_seg(loaded_png):
+ result = np.zeros(loaded_png.shape)
+ result[(loaded_png > 55) & (loaded_png <= 70)] = 1 # liver
+ result[(loaded_png > 110) & (loaded_png <= 135)] = 2 # right kidney
+ result[(loaded_png > 175) & (loaded_png <= 200)] = 3 # left kidney
+ result[(loaded_png > 240) & (loaded_png <= 255)] = 4 # spleen
+ return result
+
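+# Illustrative sketch: CHAOS MR ground-truth PNGs encode organs as intensity
+# bands, so e.g. a pixel value of 63 (in (55, 70]) maps to label 1 (liver)
+# and 252 (in (240, 255]) maps to label 4 (spleen).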
+
+def convert_seg_to_intensity_task5(seg):
+ seg_new = np.zeros(seg.shape, dtype=np.uint8)
+ seg_new[seg == 1] = 63
+ seg_new[seg == 2] = 126
+ seg_new[seg == 3] = 189
+ seg_new[seg == 4] = 252
+ return seg_new
+
+
+def convert_seg_to_intensity_task3(seg):
+ seg_new = np.zeros(seg.shape, dtype=np.uint8)
+ seg_new[seg == 1] = 63
+ return seg_new
+
+
+def write_pngs_from_nifti(nifti, output_folder, converter=convert_seg_to_intensity_task3):
+ npy = sitk.GetArrayFromImage(sitk.ReadImage(nifti))
+ seg_new = converter(npy)
+ for z in range(len(npy)):
+ Image.fromarray(seg_new[z]).save(join(output_folder, "img%03.0d.png" % z))
+
+
+def convert_variant2_predicted_test_to_submission_format(folder_with_predictions,
+ output_folder="/home/fabian/drives/datasets/results/nnUNet/test_sets/Task038_CHAOS_Task_3_5_Variant2/ready_to_submit",
+ postprocessing_file="/home/fabian/drives/datasets/results/nnUNet/ensembles/Task038_CHAOS_Task_3_5_Variant2/ensemble_2d__nnUNetTrainerV2__nnUNetPlansv2.1--3d_fullres__nnUNetTrainerV2__nnUNetPlansv2.1/postprocessing.json"):
+ """
+ output_folder is where the extracted template is
+ :param folder_with_predictions:
+ :param output_folder:
+ :return:
+ """
+ postprocessing_file = "/media/fabian/Results/nnUNet/3d_fullres/Task039_CHAOS_Task_3_5_Variant2_highres/" \
+ "nnUNetTrainerV2__nnUNetPlansfixed/postprocessing.json"
+
+ # variant 2 treats in and out phase as two training examples, so we need to ensemble these two again
+ final_predictions_folder = join(output_folder, "final")
+ maybe_mkdir_p(final_predictions_folder)
+ t1_patient_names = [i.split("_")[-1][:-7] for i in subfiles(folder_with_predictions, prefix="T1", suffix=".nii.gz", join=False)]
+ folder_for_ensembing0 = join(output_folder, "ens0")
+ folder_for_ensembing1 = join(output_folder, "ens1")
+ maybe_mkdir_p(folder_for_ensembing0)
+ maybe_mkdir_p(folder_for_ensembing1)
+ # now copy all t1 out phases in ens0 and all in phases in ens1. Name them the same.
+ for t1 in t1_patient_names:
+ shutil.copy(join(folder_with_predictions, "T1_in_%s.npz" % t1), join(folder_for_ensembing1, "T1_%s.npz" % t1))
+ shutil.copy(join(folder_with_predictions, "T1_in_%s.pkl" % t1), join(folder_for_ensembing1, "T1_%s.pkl" % t1))
+ shutil.copy(join(folder_with_predictions, "T1_out_%s.npz" % t1), join(folder_for_ensembing0, "T1_%s.npz" % t1))
+ shutil.copy(join(folder_with_predictions, "T1_out_%s.pkl" % t1), join(folder_for_ensembing0, "T1_%s.pkl" % t1))
+ shutil.copy(join(folder_with_predictions, "plans.pkl"), join(folder_for_ensembing0, "plans.pkl"))
+ shutil.copy(join(folder_with_predictions, "plans.pkl"), join(folder_for_ensembing1, "plans.pkl"))
+
+ # there is a problem with T1_35 that I need to correct manually (different crop size, will not negatively impact results)
+ #ens0_softmax = np.load(join(folder_for_ensembing0, "T1_35.npz"))['softmax']
+ ens1_softmax = np.load(join(folder_for_ensembing1, "T1_35.npz"))['softmax']
+ #ens0_props = load_pickle(join(folder_for_ensembing0, "T1_35.pkl"))
+ #ens1_props = load_pickle(join(folder_for_ensembing1, "T1_35.pkl"))
+ ens1_softmax = ens1_softmax[:, :, :-1, :]
+ np.savez_compressed(join(folder_for_ensembing1, "T1_35.npz"), softmax=ens1_softmax)
+ shutil.copy(join(folder_for_ensembing0, "T1_35.pkl"), join(folder_for_ensembing1, "T1_35.pkl"))
+
+ # now call my ensemble function
+ merge((folder_for_ensembing0, folder_for_ensembing1), final_predictions_folder, 8, True,
+ postprocessing_file=postprocessing_file)
+ # copy t2 files to final_predictions_folder as well
+ t2_files = subfiles(folder_with_predictions, prefix="T2", suffix=".nii.gz", join=False)
+ for t2 in t2_files:
+ shutil.copy(join(folder_with_predictions, t2), join(final_predictions_folder, t2))
+
+ # apply postprocessing
+ from nnunet.postprocessing.connected_components import apply_postprocessing_to_folder, load_postprocessing
+ postprocessed_folder = join(output_folder, "final_postprocessed")
+ for_which_classes, min_valid_obj_size = load_postprocessing(postprocessing_file)
+ apply_postprocessing_to_folder(final_predictions_folder, postprocessed_folder,
+ for_which_classes, min_valid_obj_size, 8)
+
+ # now export the niftis in the weird png format
+ # task 3
+ output_dir = join(output_folder, "CHAOS_submission_template_new", "Task3", "MR")
+ for t1 in t1_patient_names:
+ output_folder_here = join(output_dir, t1, "T1DUAL", "Results")
+ nifti_file = join(postprocessed_folder, "T1_%s.nii.gz" % t1)
+ write_pngs_from_nifti(nifti_file, output_folder_here, converter=convert_seg_to_intensity_task3)
+ for t2 in t2_files:
+ patname = t2.split("_")[-1][:-7]
+ output_folder_here = join(output_dir, patname, "T2SPIR", "Results")
+ nifti_file = join(postprocessed_folder, "T2_%s.nii.gz" % patname)
+ write_pngs_from_nifti(nifti_file, output_folder_here, converter=convert_seg_to_intensity_task3)
+
+ # task 5
+ output_dir = join(output_folder, "CHAOS_submission_template_new", "Task5", "MR")
+ for t1 in t1_patient_names:
+ output_folder_here = join(output_dir, t1, "T1DUAL", "Results")
+ nifti_file = join(postprocessed_folder, "T1_%s.nii.gz" % t1)
+ write_pngs_from_nifti(nifti_file, output_folder_here, converter=convert_seg_to_intensity_task5)
+ for t2 in t2_files:
+ patname = t2.split("_")[-1][:-7]
+ output_folder_here = join(output_dir, patname, "T2SPIR", "Results")
+ nifti_file = join(postprocessed_folder, "T2_%s.nii.gz" % patname)
+ write_pngs_from_nifti(nifti_file, output_folder_here, converter=convert_seg_to_intensity_task5)
+
+
+
+if __name__ == "__main__":
+ """
+ This script only prepares data to participate in Task 5 and Task 5. I don't like the CT task because
+ 1) there are
+ no abdominal organs in the ground truth. In the case of CT we are supposed to train only liver while on MRI we are
+ supposed to train all organs. This would require manual modification of nnU-net to deal with this dataset. This is
+ not what nnU-net is about.
+ 2) CT Liver or multiorgan segmentation is too easy to get external data for. Therefore the challenges comes down
+ to who gets the b est external data, not who has the best algorithm. Not super interesting.
+
+ Task 3 is a subtask of Task 5 so we need to prepare the data only once.
+ Difficulty: We need to process both T1 and T2, but T1 has 2 'modalities' (phases). nnU-Net cannot handly varying
+ number of input channels. We need to be creative.
+ We deal with this by preparing 2 Variants:
+ 1) pretend we have 2 modalities for T2 as well by simply stacking a copy of the data
+ 2) treat all MRI sequences independently, so we now have 3*20 training data instead of 2*20. In inference we then
+ ensemble the results for the two t1 modalities.
+
+ Careful: We need to split manually here to ensure we stratify by patient
+ """
+
+ root = "/media/fabian/My Book/datasets/CHAOS_challenge/Train_Sets"
+ root_test = "/media/fabian/My Book/datasets/CHAOS_challenge/Test_Sets"
+ out_base = nnUNet_raw_data
+    # CT is skipped entirely (see the reasoning in the docstring above)
+
+ ##############################################################
+ # Variant 1
+ ##############################################################
+ patient_ids = []
+ patient_ids_test = []
+
+ output_folder = join(out_base, "Task037_CHAOS_Task_3_5_Variant1")
+ output_images = join(output_folder, "imagesTr")
+ output_labels = join(output_folder, "labelsTr")
+ output_imagesTs = join(output_folder, "imagesTs")
+ maybe_mkdir_p(output_images)
+ maybe_mkdir_p(output_labels)
+ maybe_mkdir_p(output_imagesTs)
+
+
+ # Process T1 train
+ d = join(root, "MR")
+ patients = subdirs(d, join=False)
+ for p in patients:
+ patient_name = "T1_" + p
+ gt_dir = join(d, p, "T1DUAL", "Ground")
+ seg = convert_MR_seg(load_png_stack(gt_dir)[::-1])
+
+ img_dir = join(d, p, "T1DUAL", "DICOM_anon", "InPhase")
+ img_outfile = join(output_images, patient_name + "_0000.nii.gz")
+ _ = dicom2nifti.convert_dicom.dicom_series_to_nifti(img_dir, img_outfile, reorient_nifti=False)
+
+ img_dir = join(d, p, "T1DUAL", "DICOM_anon", "OutPhase")
+ img_outfile = join(output_images, patient_name + "_0001.nii.gz")
+ _ = dicom2nifti.convert_dicom.dicom_series_to_nifti(img_dir, img_outfile, reorient_nifti=False)
+
+ img_sitk = sitk.ReadImage(img_outfile)
+ img_sitk_npy = sitk.GetArrayFromImage(img_sitk)
+ seg_itk = sitk.GetImageFromArray(seg.astype(np.uint8))
+ seg_itk = copy_geometry(seg_itk, img_sitk)
+ sitk.WriteImage(seg_itk, join(output_labels, patient_name + ".nii.gz"))
+ patient_ids.append(patient_name)
+
+ # Process T1 test
+ d = join(root_test, "MR")
+ patients = subdirs(d, join=False)
+ for p in patients:
+ patient_name = "T1_" + p
+
+ img_dir = join(d, p, "T1DUAL", "DICOM_anon", "InPhase")
+ img_outfile = join(output_imagesTs, patient_name + "_0000.nii.gz")
+ _ = dicom2nifti.convert_dicom.dicom_series_to_nifti(img_dir, img_outfile, reorient_nifti=False)
+
+ img_dir = join(d, p, "T1DUAL", "DICOM_anon", "OutPhase")
+ img_outfile = join(output_imagesTs, patient_name + "_0001.nii.gz")
+ _ = dicom2nifti.convert_dicom.dicom_series_to_nifti(img_dir, img_outfile, reorient_nifti=False)
+
+ img_sitk = sitk.ReadImage(img_outfile)
+ img_sitk_npy = sitk.GetArrayFromImage(img_sitk)
+ patient_ids_test.append(patient_name)
+
+ # Process T2 train
+ d = join(root, "MR")
+ patients = subdirs(d, join=False)
+ for p in patients:
+ patient_name = "T2_" + p
+
+ gt_dir = join(d, p, "T2SPIR", "Ground")
+ seg = convert_MR_seg(load_png_stack(gt_dir)[::-1])
+
+ img_dir = join(d, p, "T2SPIR", "DICOM_anon")
+ img_outfile = join(output_images, patient_name + "_0000.nii.gz")
+ _ = dicom2nifti.convert_dicom.dicom_series_to_nifti(img_dir, img_outfile, reorient_nifti=False)
+ shutil.copy(join(output_images, patient_name + "_0000.nii.gz"), join(output_images, patient_name + "_0001.nii.gz"))
+
+ img_sitk = sitk.ReadImage(img_outfile)
+ img_sitk_npy = sitk.GetArrayFromImage(img_sitk)
+ seg_itk = sitk.GetImageFromArray(seg.astype(np.uint8))
+ seg_itk = copy_geometry(seg_itk, img_sitk)
+ sitk.WriteImage(seg_itk, join(output_labels, patient_name + ".nii.gz"))
+ patient_ids.append(patient_name)
+
+ # Process T2 test
+ d = join(root_test, "MR")
+ patients = subdirs(d, join=False)
+ for p in patients:
+ patient_name = "T2_" + p
+
+ gt_dir = join(d, p, "T2SPIR", "Ground")
+
+ img_dir = join(d, p, "T2SPIR", "DICOM_anon")
+ img_outfile = join(output_imagesTs, patient_name + "_0000.nii.gz")
+ _ = dicom2nifti.convert_dicom.dicom_series_to_nifti(img_dir, img_outfile, reorient_nifti=False)
+ shutil.copy(join(output_imagesTs, patient_name + "_0000.nii.gz"), join(output_imagesTs, patient_name + "_0001.nii.gz"))
+
+ img_sitk = sitk.ReadImage(img_outfile)
+ img_sitk_npy = sitk.GetArrayFromImage(img_sitk)
+ patient_ids_test.append(patient_name)
+
+ json_dict = OrderedDict()
+ json_dict['name'] = "Chaos Challenge Task3/5 Variant 1"
+ json_dict['description'] = "nothing"
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "https://chaos.grand-challenge.org/Data/"
+ json_dict['licence'] = "see https://chaos.grand-challenge.org/Data/"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "MRI",
+ "1": "MRI",
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "liver",
+ "2": "right kidney",
+ "3": "left kidney",
+ "4": "spleen",
+ }
+ json_dict['numTraining'] = len(patient_ids)
+ json_dict['numTest'] = 0
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i, "label": "./labelsTr/%s.nii.gz" % i} for i in
+ patient_ids]
+ json_dict['test'] = []
+
+ save_json(json_dict, join(output_folder, "dataset.json"))
+
+ ##############################################################
+ # Variant 2
+ ##############################################################
+
+ patient_ids = []
+ patient_ids_test = []
+
+ output_folder = join(out_base, "Task038_CHAOS_Task_3_5_Variant2")
+ output_images = join(output_folder, "imagesTr")
+ output_imagesTs = join(output_folder, "imagesTs")
+ output_labels = join(output_folder, "labelsTr")
+ maybe_mkdir_p(output_images)
+ maybe_mkdir_p(output_imagesTs)
+ maybe_mkdir_p(output_labels)
+
+ # Process T1 train
+ d = join(root, "MR")
+ patients = subdirs(d, join=False)
+ for p in patients:
+ patient_name_in = "T1_in_" + p
+ patient_name_out = "T1_out_" + p
+ gt_dir = join(d, p, "T1DUAL", "Ground")
+ seg = convert_MR_seg(load_png_stack(gt_dir)[::-1])
+
+ img_dir = join(d, p, "T1DUAL", "DICOM_anon", "InPhase")
+ img_outfile = join(output_images, patient_name_in + "_0000.nii.gz")
+ _ = dicom2nifti.convert_dicom.dicom_series_to_nifti(img_dir, img_outfile, reorient_nifti=False)
+
+ img_dir = join(d, p, "T1DUAL", "DICOM_anon", "OutPhase")
+ img_outfile = join(output_images, patient_name_out + "_0000.nii.gz")
+ _ = dicom2nifti.convert_dicom.dicom_series_to_nifti(img_dir, img_outfile, reorient_nifti=False)
+
+ img_sitk = sitk.ReadImage(img_outfile)
+ img_sitk_npy = sitk.GetArrayFromImage(img_sitk)
+ seg_itk = sitk.GetImageFromArray(seg.astype(np.uint8))
+ seg_itk = copy_geometry(seg_itk, img_sitk)
+ sitk.WriteImage(seg_itk, join(output_labels, patient_name_in + ".nii.gz"))
+ sitk.WriteImage(seg_itk, join(output_labels, patient_name_out + ".nii.gz"))
+ patient_ids.append(patient_name_out)
+ patient_ids.append(patient_name_in)
+
+ # Process T1 test
+ d = join(root_test, "MR")
+ patients = subdirs(d, join=False)
+ for p in patients:
+ patient_name_in = "T1_in_" + p
+ patient_name_out = "T1_out_" + p
+ gt_dir = join(d, p, "T1DUAL", "Ground")
+
+ img_dir = join(d, p, "T1DUAL", "DICOM_anon", "InPhase")
+ img_outfile = join(output_imagesTs, patient_name_in + "_0000.nii.gz")
+ _ = dicom2nifti.convert_dicom.dicom_series_to_nifti(img_dir, img_outfile, reorient_nifti=False)
+
+ img_dir = join(d, p, "T1DUAL", "DICOM_anon", "OutPhase")
+ img_outfile = join(output_imagesTs, patient_name_out + "_0000.nii.gz")
+ _ = dicom2nifti.convert_dicom.dicom_series_to_nifti(img_dir, img_outfile, reorient_nifti=False)
+
+ img_sitk = sitk.ReadImage(img_outfile)
+ img_sitk_npy = sitk.GetArrayFromImage(img_sitk)
+ patient_ids_test.append(patient_name_out)
+ patient_ids_test.append(patient_name_in)
+
+ # Process T2 train
+ d = join(root, "MR")
+ patients = subdirs(d, join=False)
+ for p in patients:
+ patient_name = "T2_" + p
+
+ gt_dir = join(d, p, "T2SPIR", "Ground")
+ seg = convert_MR_seg(load_png_stack(gt_dir)[::-1])
+
+ img_dir = join(d, p, "T2SPIR", "DICOM_anon")
+ img_outfile = join(output_images, patient_name + "_0000.nii.gz")
+ _ = dicom2nifti.convert_dicom.dicom_series_to_nifti(img_dir, img_outfile, reorient_nifti=False)
+
+ img_sitk = sitk.ReadImage(img_outfile)
+ img_sitk_npy = sitk.GetArrayFromImage(img_sitk)
+ seg_itk = sitk.GetImageFromArray(seg.astype(np.uint8))
+ seg_itk = copy_geometry(seg_itk, img_sitk)
+ sitk.WriteImage(seg_itk, join(output_labels, patient_name + ".nii.gz"))
+ patient_ids.append(patient_name)
+
+ # Process T2 test
+ d = join(root_test, "MR")
+ patients = subdirs(d, join=False)
+ for p in patients:
+ patient_name = "T2_" + p
+
+ gt_dir = join(d, p, "T2SPIR", "Ground")
+
+ img_dir = join(d, p, "T2SPIR", "DICOM_anon")
+ img_outfile = join(output_imagesTs, patient_name + "_0000.nii.gz")
+ _ = dicom2nifti.convert_dicom.dicom_series_to_nifti(img_dir, img_outfile, reorient_nifti=False)
+
+ img_sitk = sitk.ReadImage(img_outfile)
+ img_sitk_npy = sitk.GetArrayFromImage(img_sitk)
+ patient_ids_test.append(patient_name)
+
+ json_dict = OrderedDict()
+ json_dict['name'] = "Chaos Challenge Task3/5 Variant 2"
+ json_dict['description'] = "nothing"
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "https://chaos.grand-challenge.org/Data/"
+ json_dict['licence'] = "see https://chaos.grand-challenge.org/Data/"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "MRI",
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "liver",
+ "2": "right kidney",
+ "3": "left kidney",
+ "4": "spleen",
+ }
+ json_dict['numTraining'] = len(patient_ids)
+ json_dict['numTest'] = 0
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i, "label": "./labelsTr/%s.nii.gz" % i} for i in
+ patient_ids]
+ json_dict['test'] = []
+
+ save_json(json_dict, join(output_folder, "dataset.json"))
+
+ #################################################
+ # custom split
+ #################################################
+ patients = subdirs(join(root, "MR"), join=False)
+ task_name_variant1 = "Task037_CHAOS_Task_3_5_Variant1"
+ task_name_variant2 = "Task038_CHAOS_Task_3_5_Variant2"
+
+ output_preprocessed_v1 = join(preprocessing_output_dir, task_name_variant1)
+ maybe_mkdir_p(output_preprocessed_v1)
+
+ output_preprocessed_v2 = join(preprocessing_output_dir, task_name_variant2)
+ maybe_mkdir_p(output_preprocessed_v2)
+
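+    # the same seed (12345) is used for both variants below so that the patient-level folds are
+    # identical across Task037 and Task038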
+ splits = []
+ for fold in range(5):
+ tr, val = get_split_deterministic(patients, fold, 5, 12345)
+ train = ["T2_" + i for i in tr] + ["T1_" + i for i in tr]
+ validation = ["T2_" + i for i in val] + ["T1_" + i for i in val]
+ splits.append({
+ 'train': train,
+ 'val': validation
+ })
+ save_pickle(splits, join(output_preprocessed_v1, "splits_final.pkl"))
+
+ splits = []
+ for fold in range(5):
+ tr, val = get_split_deterministic(patients, fold, 5, 12345)
+ train = ["T2_" + i for i in tr] + ["T1_in_" + i for i in tr] + ["T1_out_" + i for i in tr]
+ validation = ["T2_" + i for i in val] + ["T1_in_" + i for i in val] + ["T1_out_" + i for i in val]
+ splits.append({
+ 'train': train,
+ 'val': validation
+ })
+ save_pickle(splits, join(output_preprocessed_v2, "splits_final.pkl"))
+
diff --git a/nnunet/dataset_conversion/Task040_KiTS.py b/nnunet/dataset_conversion/Task040_KiTS.py
new file mode 100644
index 0000000000000000000000000000000000000000..e045e3aadc7a8c64bddeff7ff5be416855894e3b
--- /dev/null
+++ b/nnunet/dataset_conversion/Task040_KiTS.py
@@ -0,0 +1,240 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from copy import deepcopy
+
+from batchgenerators.utilities.file_and_folder_operations import *
+import shutil
+import SimpleITK as sitk
+from multiprocessing import Pool
+from medpy.metric import dc
+import numpy as np
+from nnunet.paths import network_training_output_dir
+from scipy.ndimage import label
+
+
+def compute_dice_scores(ref: str, pred: str):
+ ref = sitk.GetArrayFromImage(sitk.ReadImage(ref))
+ pred = sitk.GetArrayFromImage(sitk.ReadImage(pred))
+ kidney_mask_ref = ref > 0
+ kidney_mask_pred = pred > 0
+ if np.sum(kidney_mask_pred) == 0 and kidney_mask_ref.sum() == 0:
+ kidney_dice = np.nan
+ else:
+ kidney_dice = dc(kidney_mask_pred, kidney_mask_ref)
+
+ tumor_mask_ref = ref == 2
+ tumor_mask_pred = pred == 2
+ if np.sum(tumor_mask_ref) == 0 and tumor_mask_pred.sum() == 0:
+ tumor_dice = np.nan
+ else:
+ tumor_dice = dc(tumor_mask_ref, tumor_mask_pred)
+
+    mean_dice = np.mean((kidney_dice, tumor_dice))  # arithmetic mean of kidney and tumor Dice
+    return kidney_dice, tumor_dice, mean_dice
+
+
+def evaluate_folder(folder_gt: str, folder_pred: str):
+ p = Pool(8)
+ niftis = subfiles(folder_gt, suffix=".nii.gz", join=False)
+ images_gt = [join(folder_gt, i) for i in niftis]
+ images_pred = [join(folder_pred, i) for i in niftis]
+ results = p.starmap(compute_dice_scores, zip(images_gt, images_pred))
+ p.close()
+ p.join()
+
+ with open(join(folder_pred, "results.csv"), 'w') as f:
+ for i, ni in enumerate(niftis):
+ f.write("%s,%0.4f,%0.4f,%0.4f\n" % (ni, *results[i]))
+
+
+def remove_all_but_the_two_largest_conn_comp(img_itk_file: str, file_out: str):
+ """
+ This was not used. I was just curious because others used this. Turns out this is not necessary for my networks
+ """
+ img_itk = sitk.ReadImage(img_itk_file)
+ img_npy = sitk.GetArrayFromImage(img_itk)
+
+ labelmap, num_labels = label((img_npy > 0).astype(int))
+
+ if num_labels > 2:
+ label_sizes = []
+ for i in range(1, num_labels + 1):
+ label_sizes.append(np.sum(labelmap == i))
+ argsrt = np.argsort(label_sizes)[::-1] # two largest are now argsrt[0] and argsrt[1]
+ keep_mask = (labelmap == argsrt[0] + 1) | (labelmap == argsrt[1] + 1)
+ img_npy[~keep_mask] = 0
+ new = sitk.GetImageFromArray(img_npy)
+ new.CopyInformation(img_itk)
+ sitk.WriteImage(new, file_out)
+ print(os.path.basename(img_itk_file), num_labels, label_sizes)
+ else:
+ shutil.copy(img_itk_file, file_out)
+
+
+def manual_postprocess(folder_in,
+ folder_out):
+ """
+ This was not used. I was just curious because others used this. Turns out this is not necessary for my networks
+ """
+ maybe_mkdir_p(folder_out)
+ infiles = subfiles(folder_in, suffix=".nii.gz", join=False)
+
+ outfiles = [join(folder_out, i) for i in infiles]
+ infiles = [join(folder_in, i) for i in infiles]
+
+ p = Pool(8)
+ _ = p.starmap_async(remove_all_but_the_two_largest_conn_comp, zip(infiles, outfiles))
+ _ = _.get()
+ p.close()
+ p.join()
+
+
+
+
+def copy_npz_from_valsets():
+    """
+    This is preparation for ensembling.
+    :return:
+    """
+ base = join(network_training_output_dir, "3d_lowres/Task048_KiTS_clean")
+ folders = ['nnUNetTrainerNewCandidate23_FabiansPreActResNet__nnUNetPlans',
+ 'nnUNetTrainerNewCandidate23_FabiansResNet__nnUNetPlans',
+ 'nnUNetTrainerNewCandidate23__nnUNetPlans']
+ for f in folders:
+ out = join(base, f, 'crossval_npz')
+ maybe_mkdir_p(out)
+ shutil.copy(join(base, f, 'plans.pkl'), out)
+ for fold in range(5):
+ cur = join(base, f, 'fold_%d' % fold, 'validation_raw')
+ npz_files = subfiles(cur, suffix='.npz', join=False)
+ pkl_files = [i[:-3] + 'pkl' for i in npz_files]
+ assert all([isfile(join(cur, i)) for i in pkl_files])
+ for n in npz_files:
+ corresponding_pkl = n[:-3] + 'pkl'
+ shutil.copy(join(cur, n), out)
+ shutil.copy(join(cur, corresponding_pkl), out)
+
+
+def ensemble(experiments=('nnUNetTrainerNewCandidate23_FabiansPreActResNet__nnUNetPlans',
+ 'nnUNetTrainerNewCandidate23_FabiansResNet__nnUNetPlans'), out_dir="/media/fabian/Results/nnUNet/3d_lowres/Task048_KiTS_clean/ensemble_preactres_and_res"):
+ from nnunet.inference.ensemble_predictions import merge
+ folders = [join(network_training_output_dir, "3d_lowres/Task048_KiTS_clean", i, 'crossval_npz') for i in experiments]
+ merge(folders, out_dir, 8)
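+
+
+# presumed workflow for these helpers (the order is an assumption, inferred from this file):
+# copy_npz_from_valsets() -> nnUNetTrainer_these() -> ensemble() -> reset_trainerName_these()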
+
+
+def prepare_submission(fld= "/home/fabian/drives/datasets/results/nnUNet/test_sets/Task048_KiTS_clean/predicted_ens_3d_fullres_3d_cascade_fullres_postprocessed", # '/home/fabian/datasets_fabian/predicted_KiTS_nnUNetTrainerNewCandidate23_FabiansResNet',
+ out='/home/fabian/drives/datasets/results/nnUNet/test_sets/Task048_KiTS_clean/submission'):
+ nii = subfiles(fld, join=False, suffix='.nii.gz')
+ maybe_mkdir_p(out)
+ for n in nii:
+ outfname = n.replace('case', 'prediction')
+ shutil.copy(join(fld, n), join(out, outfname))
+
+
+def pretend_to_be_nnUNetTrainer(base, folds=(0, 1, 2, 3, 4)):
+    """
+    changes the trainer class name in the best-checkpoint pickles to nnUNetTrainer
+    :param base: experiment folder containing the fold_X subfolders
+    :param folds: folds to process
+    :return:
+    """
+ for fold in folds:
+ cur = join(base, "fold_%d" % fold)
+ pkl_file = join(cur, 'model_best.model.pkl')
+ a = load_pickle(pkl_file)
+ a['name_old'] = deepcopy(a['name'])
+ a['name'] = 'nnUNetTrainer'
+ save_pickle(a, pkl_file)
+
+
+def reset_trainerName(base, folds=(0, 1, 2, 3, 4)):
+ for fold in folds:
+ cur = join(base, "fold_%d" % fold)
+ pkl_file = join(cur, 'model_best.model.pkl')
+ a = load_pickle(pkl_file)
+ a['name'] = a['name_old']
+ del a['name_old']
+ save_pickle(a, pkl_file)
+
+
+def nnUNetTrainer_these(experiments=('nnUNetTrainerNewCandidate23_FabiansPreActResNet__nnUNetPlans',
+ 'nnUNetTrainerNewCandidate23_FabiansResNet__nnUNetPlans',
+ 'nnUNetTrainerNewCandidate23__nnUNetPlans')):
+ """
+ changes best checkpoint pickle nnunettrainer class name to nnUNetTrainer
+ :param experiments:
+ :return:
+ """
+ base = join(network_training_output_dir, "3d_lowres/Task048_KiTS_clean")
+ for exp in experiments:
+ cur = join(base, exp)
+        pretend_to_be_nnUNetTrainer(cur)
+
+
+def reset_trainerName_these(experiments=('nnUNetTrainerNewCandidate23_FabiansPreActResNet__nnUNetPlans',
+ 'nnUNetTrainerNewCandidate23_FabiansResNet__nnUNetPlans',
+ 'nnUNetTrainerNewCandidate23__nnUNetPlans')):
+ """
+ changes best checkpoint pickle nnunettrainer class name to nnUNetTrainer
+ :param experiments:
+ :return:
+ """
+ base = join(network_training_output_dir, "3d_lowres/Task048_KiTS_clean")
+ for exp in experiments:
+ cur = join(base, exp)
+ reset_trainerName(cur)
+
+
+if __name__ == "__main__":
+ base = "/media/fabian/My Book/datasets/KiTS2019_Challenge/kits19/data"
+ out = "/media/fabian/My Book/MedicalDecathlon/nnUNet_raw_splitted/Task040_KiTS"
+ cases = subdirs(base, join=False)
+
+ maybe_mkdir_p(out)
+ maybe_mkdir_p(join(out, "imagesTr"))
+ maybe_mkdir_p(join(out, "imagesTs"))
+ maybe_mkdir_p(join(out, "labelsTr"))
+
+    train_cases = []
+    for c in cases:
+        case_id = int(c.split("_")[-1])
+        if case_id < 210:
+            shutil.copy(join(base, c, "imaging.nii.gz"), join(out, "imagesTr", c + "_0000.nii.gz"))
+            shutil.copy(join(base, c, "segmentation.nii.gz"), join(out, "labelsTr", c + ".nii.gz"))
+            train_cases.append(c)
+        else:
+            shutil.copy(join(base, c, "imaging.nii.gz"), join(out, "imagesTs", c + "_0000.nii.gz"))
+
+ json_dict = {}
+ json_dict['name'] = "KiTS"
+ json_dict['description'] = "kidney and kidney tumor segmentation"
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "KiTS data for nnunet"
+ json_dict['licence'] = ""
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "CT",
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "Kidney",
+ "2": "Tumor"
+ }
+    json_dict['numTraining'] = len(train_cases)
+    json_dict['numTest'] = 0
+    # only the labeled cases (case_00000-case_00209) go into the training list
+    json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i, "label": "./labelsTr/%s.nii.gz" % i} for i in
+                             train_cases]
+ json_dict['test'] = []
+
+ save_json(json_dict, os.path.join(out, "dataset.json"))
+
diff --git a/nnunet/dataset_conversion/Task043_BraTS_2019.py b/nnunet/dataset_conversion/Task043_BraTS_2019.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e2ea37ec235cfccfffb4117026d44fe6caee3fa
--- /dev/null
+++ b/nnunet/dataset_conversion/Task043_BraTS_2019.py
@@ -0,0 +1,164 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+from collections import OrderedDict
+
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.paths import nnUNet_raw_data
+import SimpleITK as sitk
+import shutil
+
+
+def copy_BraTS_segmentation_and_convert_labels(in_file, out_file):
+ # use this for segmentation only!!!
+ # nnUNet wants the labels to be continuous. BraTS is 0, 1, 2, 4 -> we make that into 0, 1, 2, 3
+ img = sitk.ReadImage(in_file)
+ img_npy = sitk.GetArrayFromImage(img)
+
+ uniques = np.unique(img_npy)
+ for u in uniques:
+ if u not in [0, 1, 2, 4]:
+ raise RuntimeError('unexpected label')
+
+ seg_new = np.zeros_like(img_npy)
+ seg_new[img_npy == 4] = 3
+ seg_new[img_npy == 2] = 1
+ seg_new[img_npy == 1] = 2
+ img_corr = sitk.GetImageFromArray(seg_new)
+ img_corr.CopyInformation(img)
+ sitk.WriteImage(img_corr, out_file)
+
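+
+def convert_labels_back_to_BraTS(in_file: str, out_file: str):
+    """
+    Minimal sketch, not part of the original script: inverts the label mapping of
+    copy_BraTS_segmentation_and_convert_labels so that nnU-Net predictions (0, 1, 2, 3)
+    can be exported back to the BraTS convention (0, 1, 2, 4) for submission.
+    """
+    img = sitk.ReadImage(in_file)
+    img_npy = sitk.GetArrayFromImage(img)
+    seg_new = np.zeros_like(img_npy)
+    seg_new[img_npy == 3] = 4  # enhancing tumor back to BraTS label 4
+    seg_new[img_npy == 1] = 2  # edema back to BraTS label 2
+    seg_new[img_npy == 2] = 1  # non-enhancing back to BraTS label 1
+    img_corr = sitk.GetImageFromArray(seg_new)
+    img_corr.CopyInformation(img)
+    sitk.WriteImage(img_corr, out_file)
+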
+
+if __name__ == "__main__":
+ """
+ REMEMBER TO CONVERT LABELS BACK TO BRATS CONVENTION AFTER PREDICTION!
+ """
+
+ task_name = "Task043_BraTS2019"
+ downloaded_data_dir = "/home/sdp/MLPERF/Brats2019_DATA/MICCAI_BraTS_2019_Data_Training"
+
+ target_base = join(nnUNet_raw_data, task_name)
+ target_imagesTr = join(target_base, "imagesTr")
+ target_imagesVal = join(target_base, "imagesVal")
+ target_imagesTs = join(target_base, "imagesTs")
+ target_labelsTr = join(target_base, "labelsTr")
+
+ maybe_mkdir_p(target_imagesTr)
+ maybe_mkdir_p(target_imagesVal)
+ maybe_mkdir_p(target_imagesTs)
+ maybe_mkdir_p(target_labelsTr)
+
+ patient_names = []
+ for tpe in ["HGG", "LGG"]:
+ cur = join(downloaded_data_dir, tpe)
+ for p in subdirs(cur, join=False):
+ patdir = join(cur, p)
+ patient_name = tpe + "__" + p
+ patient_names.append(patient_name)
+ t1 = join(patdir, p + "_t1.nii.gz")
+ t1c = join(patdir, p + "_t1ce.nii.gz")
+ t2 = join(patdir, p + "_t2.nii.gz")
+ flair = join(patdir, p + "_flair.nii.gz")
+ seg = join(patdir, p + "_seg.nii.gz")
+
+ assert all([
+ isfile(t1),
+ isfile(t1c),
+ isfile(t2),
+ isfile(flair),
+ isfile(seg)
+ ]), "%s" % patient_name
+
+ shutil.copy(t1, join(target_imagesTr, patient_name + "_0000.nii.gz"))
+ shutil.copy(t1c, join(target_imagesTr, patient_name + "_0001.nii.gz"))
+ shutil.copy(t2, join(target_imagesTr, patient_name + "_0002.nii.gz"))
+ shutil.copy(flair, join(target_imagesTr, patient_name + "_0003.nii.gz"))
+
+ copy_BraTS_segmentation_and_convert_labels(seg, join(target_labelsTr, patient_name + ".nii.gz"))
+
+
+ json_dict = OrderedDict()
+ json_dict['name'] = "BraTS2019"
+ json_dict['description'] = "nothing"
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "see BraTS2019"
+ json_dict['licence'] = "see BraTS2019 license"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "T1",
+ "1": "T1ce",
+ "2": "T2",
+ "3": "FLAIR"
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "edema",
+ "2": "non-enhancing",
+ "3": "enhancing",
+ }
+ json_dict['numTraining'] = len(patient_names)
+ json_dict['numTest'] = 0
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i, "label": "./labelsTr/%s.nii.gz" % i} for i in
+ patient_names]
+ json_dict['test'] = []
+
+ save_json(json_dict, join(target_base, "dataset.json"))
+
+ downloaded_data_dir = "/home/sdp/MLPERF/Brats2019_DATA/MICCAI_BraTS_2019_Data_Validation"
+
+ for p in subdirs(downloaded_data_dir, join=False):
+ patdir = join(downloaded_data_dir, p)
+ patient_name = p
+ t1 = join(patdir, p + "_t1.nii.gz")
+ t1c = join(patdir, p + "_t1ce.nii.gz")
+ t2 = join(patdir, p + "_t2.nii.gz")
+ flair = join(patdir, p + "_flair.nii.gz")
+
+ assert all([
+ isfile(t1),
+ isfile(t1c),
+ isfile(t2),
+ isfile(flair),
+ ]), "%s" % patient_name
+
+ shutil.copy(t1, join(target_imagesVal, patient_name + "_0000.nii.gz"))
+ shutil.copy(t1c, join(target_imagesVal, patient_name + "_0001.nii.gz"))
+ shutil.copy(t2, join(target_imagesVal, patient_name + "_0002.nii.gz"))
+ shutil.copy(flair, join(target_imagesVal, patient_name + "_0003.nii.gz"))
+
+ """
+ #I dont have the testing data
+ downloaded_data_dir = "/home/fabian/Downloads/BraTS2018_train_val_test_data/MICCAI_BraTS_2018_Data_Testing_FIsensee"
+
+ for p in subdirs(downloaded_data_dir, join=False):
+ patdir = join(downloaded_data_dir, p)
+ patient_name = p
+ t1 = join(patdir, p + "_t1.nii.gz")
+ t1c = join(patdir, p + "_t1ce.nii.gz")
+ t2 = join(patdir, p + "_t2.nii.gz")
+ flair = join(patdir, p + "_flair.nii.gz")
+
+ assert all([
+ isfile(t1),
+ isfile(t1c),
+ isfile(t2),
+ isfile(flair),
+ ]), "%s" % patient_name
+
+ shutil.copy(t1, join(target_imagesTs, patient_name + "_0000.nii.gz"))
+ shutil.copy(t1c, join(target_imagesTs, patient_name + "_0001.nii.gz"))
+ shutil.copy(t2, join(target_imagesTs, patient_name + "_0002.nii.gz"))
+ shutil.copy(flair, join(target_imagesTs, patient_name + "_0003.nii.gz"))"""
diff --git a/nnunet/dataset_conversion/Task055_SegTHOR.py b/nnunet/dataset_conversion/Task055_SegTHOR.py
new file mode 100644
index 0000000000000000000000000000000000000000..656764e12b407e194ba6673f7ad002cf105f0029
--- /dev/null
+++ b/nnunet/dataset_conversion/Task055_SegTHOR.py
@@ -0,0 +1,98 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from collections import OrderedDict
+from nnunet.paths import nnUNet_raw_data
+from batchgenerators.utilities.file_and_folder_operations import *
+import shutil
+import SimpleITK as sitk
+
+
+def convert_for_submission(source_dir, target_dir):
+ """
+ I believe they want .nii, not .nii.gz
+ :param source_dir:
+ :param target_dir:
+ :return:
+ """
+ files = subfiles(source_dir, suffix=".nii.gz", join=False)
+ maybe_mkdir_p(target_dir)
+ for f in files:
+ img = sitk.ReadImage(join(source_dir, f))
+ out_file = join(target_dir, f[:-7] + ".nii")
+ sitk.WriteImage(img, out_file)
+
+
+
+if __name__ == "__main__":
+ base = "/media/fabian/DeepLearningData/SegTHOR"
+
+ task_id = 55
+ task_name = "SegTHOR"
+
+ foldername = "Task%03.0d_%s" % (task_id, task_name)
+
+ out_base = join(nnUNet_raw_data, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ train_patient_names = []
+ test_patient_names = []
+ train_patients = subfolders(join(base, "train"), join=False)
+ for p in train_patients:
+ curr = join(base, "train", p)
+ label_file = join(curr, "GT.nii.gz")
+ image_file = join(curr, p + ".nii.gz")
+ shutil.copy(image_file, join(imagestr, p + "_0000.nii.gz"))
+ shutil.copy(label_file, join(labelstr, p + ".nii.gz"))
+ train_patient_names.append(p)
+
+ test_patients = subfiles(join(base, "test"), join=False, suffix=".nii.gz")
+ for p in test_patients:
+ p = p[:-7]
+ curr = join(base, "test")
+ image_file = join(curr, p + ".nii.gz")
+ shutil.copy(image_file, join(imagests, p + "_0000.nii.gz"))
+ test_patient_names.append(p)
+
+
+ json_dict = OrderedDict()
+ json_dict['name'] = "SegTHOR"
+ json_dict['description'] = "SegTHOR"
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "see challenge website"
+ json_dict['licence'] = "see challenge website"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "CT",
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "esophagus",
+ "2": "heart",
+ "3": "trachea",
+ "4": "aorta",
+ }
+ json_dict['numTraining'] = len(train_patient_names)
+ json_dict['numTest'] = len(test_patient_names)
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i.split("/")[-1], "label": "./labelsTr/%s.nii.gz" % i.split("/")[-1]} for i in
+ train_patient_names]
+ json_dict['test'] = ["./imagesTs/%s.nii.gz" % i.split("/")[-1] for i in test_patient_names]
+
+ save_json(json_dict, os.path.join(out_base, "dataset.json"))
diff --git a/nnunet/dataset_conversion/Task056_VerSe2019.py b/nnunet/dataset_conversion/Task056_VerSe2019.py
new file mode 100644
index 0000000000000000000000000000000000000000..4962ec9ae634b319821199d08b665e83e44b2367
--- /dev/null
+++ b/nnunet/dataset_conversion/Task056_VerSe2019.py
@@ -0,0 +1,274 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from collections import OrderedDict
+import SimpleITK as sitk
+from multiprocessing.pool import Pool
+from nnunet.configuration import default_num_threads
+from nnunet.paths import nnUNet_raw_data
+from batchgenerators.utilities.file_and_folder_operations import *
+import shutil
+from medpy import metric
+import numpy as np
+from nnunet.utilities.image_reorientation import reorient_all_images_in_folder_to_ras
+
+
+def check_if_all_in_good_orientation(imagesTr_folder: str, labelsTr_folder: str, output_folder: str) -> None:
+ maybe_mkdir_p(output_folder)
+ filenames = subfiles(labelsTr_folder, suffix='.nii.gz', join=False)
+ import matplotlib.pyplot as plt
+ for n in filenames:
+ img = sitk.GetArrayFromImage(sitk.ReadImage(join(imagesTr_folder, n[:-7] + '_0000.nii.gz')))
+ lab = sitk.GetArrayFromImage(sitk.ReadImage(join(labelsTr_folder, n)))
+ assert np.all([i == j for i, j in zip(img.shape, lab.shape)])
+ z_slice = img.shape[0] // 2
+ img_slice = img[z_slice]
+ lab_slice = lab[z_slice]
+ lab_slice[lab_slice != 0] = 1
+ img_slice = img_slice - img_slice.min()
+ img_slice = img_slice / img_slice.max()
+ stacked = np.vstack((img_slice, lab_slice))
+ print(stacked.shape)
+ plt.imsave(join(output_folder, n[:-7] + '.png'), stacked, cmap='gray')
+
+
+def evaluate_verse_case(sitk_file_ref:str, sitk_file_test:str):
+ """
+ Only vertebra that are present in the reference will be evaluated
+ :param sitk_file_ref:
+ :param sitk_file_test:
+ :return:
+ """
+ gt_npy = sitk.GetArrayFromImage(sitk.ReadImage(sitk_file_ref))
+ pred_npy = sitk.GetArrayFromImage(sitk.ReadImage(sitk_file_test))
+ dice_scores = []
+ for label in range(1, 26):
+ mask_gt = gt_npy == label
+ if np.sum(mask_gt) > 0:
+ mask_pred = pred_npy == label
+ dc = metric.dc(mask_pred, mask_gt)
+ else:
+ dc = np.nan
+ dice_scores.append(dc)
+ return dice_scores
+
+
+def evaluate_verse_folder(folder_pred, folder_gt, out_json="/home/fabian/verse.json"):
+ p = Pool(default_num_threads)
+ files_gt_bare = subfiles(folder_gt, join=False)
+ assert all([isfile(join(folder_pred, i)) for i in files_gt_bare]), "some files are missing in the predicted folder"
+ files_pred = [join(folder_pred, i) for i in files_gt_bare]
+ files_gt = [join(folder_gt, i) for i in files_gt_bare]
+
+ results = p.starmap_async(evaluate_verse_case, zip(files_gt, files_pred))
+
+ results = results.get()
+
+ dct = {i: j for i, j in zip(files_gt_bare, results)}
+
+ results_stacked = np.vstack(results)
+ results_mean = np.nanmean(results_stacked, 0)
+ overall_mean = np.nanmean(results_mean)
+
+ save_json((dct, list(results_mean), overall_mean), out_json)
+ p.close()
+ p.join()
+
+
+def print_unique_labels_and_their_volumes(image: str, print_only_if_vol_smaller_than: float = None):
+ img = sitk.ReadImage(image)
+ voxel_volume = np.prod(img.GetSpacing())
+ img_npy = sitk.GetArrayFromImage(img)
+ uniques = [i for i in np.unique(img_npy) if i != 0]
+ volumes = {i: np.sum(img_npy == i) * voxel_volume for i in uniques}
+ print('')
+ print(image.split('/')[-1])
+ print('uniques:', uniques)
+    for k in volumes.keys():
+        v = volumes[k]
+        if print_only_if_vol_smaller_than is None or v <= print_only_if_vol_smaller_than:
+            print('k:', k, '\tvol:', v)
+
+
+def remove_label(label_file: str, remove_this: int, replace_with: int = 0):
+ img = sitk.ReadImage(label_file)
+ img_npy = sitk.GetArrayFromImage(img)
+ img_npy[img_npy == remove_this] = replace_with
+ img2 = sitk.GetImageFromArray(img_npy)
+ img2.CopyInformation(img)
+ sitk.WriteImage(img2, label_file)
+
+
+if __name__ == "__main__":
+    ### First we create an nnU-Net dataset from VerSe. After this the images will be all willy nilly in their
+    # orientation because that's how VerSe comes
+    # base = '/media/fabian/DeepLearningData/VerSe2019'  # alternative location
+    base = "/home/fabian/data/VerSe2019"
+
+
+ task_id = 56
+ task_name = "VerSe"
+
+ foldername = "Task%03.0d_%s" % (task_id, task_name)
+
+ out_base = join(nnUNet_raw_data, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ train_patient_names = [i[:-len("_seg.nii.gz")] for i in subfiles(join(base, "train"), join=False, suffix="_seg.nii.gz")]
+ for p in train_patient_names:
+ curr = join(base, "train")
+ label_file = join(curr, p + "_seg.nii.gz")
+ image_file = join(curr, p + ".nii.gz")
+ shutil.copy(image_file, join(imagestr, p + "_0000.nii.gz"))
+ shutil.copy(label_file, join(labelstr, p + ".nii.gz"))
+
+ test_patient_names = [i[:-7] for i in subfiles(join(base, "test"), join=False, suffix=".nii.gz")]
+ for p in test_patient_names:
+ curr = join(base, "test")
+ image_file = join(curr, p + ".nii.gz")
+ shutil.copy(image_file, join(imagests, p + "_0000.nii.gz"))
+
+
+ json_dict = OrderedDict()
+ json_dict['name'] = "VerSe2019"
+ json_dict['description'] = "VerSe2019"
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "see challenge website"
+ json_dict['licence'] = "see challenge website"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "CT",
+ }
+ json_dict['labels'] = {i: str(i) for i in range(26)}
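+    # (labels 1-25 are the individual vertebrae, numbered cranially to caudally following the VerSe
+    # convention; 0 is background)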
+
+ json_dict['numTraining'] = len(train_patient_names)
+ json_dict['numTest'] = len(test_patient_names)
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i.split("/")[-1], "label": "./labelsTr/%s.nii.gz" % i.split("/")[-1]} for i in
+ train_patient_names]
+ json_dict['test'] = ["./imagesTs/%s.nii.gz" % i.split("/")[-1] for i in test_patient_names]
+
+ save_json(json_dict, os.path.join(out_base, "dataset.json"))
+
+ # now we reorient all those images to ras. This saves a pkl with the original affine. We need this information to
+ # bring our predictions into the same geometry for submission
+ reorient_all_images_in_folder_to_ras(imagestr)
+ reorient_all_images_in_folder_to_ras(imagests)
+ reorient_all_images_in_folder_to_ras(labelstr)
+
+ # sanity check
+ check_if_all_in_good_orientation(imagestr, labelstr, join(out_base, 'sanitycheck'))
+ # looks good to me - proceed
+
+ # check the volumes of the vertebrae
+ _ = [print_unique_labels_and_their_volumes(i, 1000) for i in subfiles(labelstr, suffix='.nii.gz')]
+
+ # some cases appear fishy. For example, verse063.nii.gz has labels [1, 20, 21, 22, 23, 24] and 1 only has a volume
+ # of 63mm^3
+
+    # let's correct those
+
+ # 19 is connected to the image border and should not be segmented. Only one slice of 19 is segmented in the
+ # reference. Looks wrong
+ remove_label(join(labelstr, 'verse031.nii.gz'), 19, 0)
+
+ # spurious annotation of 18 (vol: 8.00)
+ remove_label(join(labelstr, 'verse060.nii.gz'), 18, 0)
+
+ # spurious annotation of 16 (vol: 3.00)
+ remove_label(join(labelstr, 'verse061.nii.gz'), 16, 0)
+
+ # spurious annotation of 1 (vol: 63.00) although the rest of the vertebra is [20, 21, 22, 23, 24]
+ remove_label(join(labelstr, 'verse063.nii.gz'), 1, 0)
+
+ # spurious annotation of 3 (vol: 9.53) although the rest of the vertebra is
+ # [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
+ remove_label(join(labelstr, 'verse074.nii.gz'), 3, 0)
+
+ # spurious annotation of 3 (vol: 15.00)
+ remove_label(join(labelstr, 'verse097.nii.gz'), 3, 0)
+
+ # spurious annotation of 3 (vol: 10) although the rest of the vertebra is
+ # [8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24]
+ remove_label(join(labelstr, 'verse151.nii.gz'), 3, 0)
+
+ # spurious annotation of 25 (vol: 4) although the rest of the vertebra is
+ # [1, 2, 3, 4, 5, 6, 7, 8, 9]
+ remove_label(join(labelstr, 'verse201.nii.gz'), 25, 0)
+
+ # spurious annotation of 23 (vol: 8) although the rest of the vertebra is
+ # [1, 2, 3, 4, 5, 6, 7, 8]
+ remove_label(join(labelstr, 'verse207.nii.gz'), 23, 0)
+
+ # spurious annotation of 23 (vol: 12) although the rest of the vertebra is
+ # [1, 2, 3, 4, 5, 6, 7, 8, 9]
+ remove_label(join(labelstr, 'verse208.nii.gz'), 23, 0)
+
+ # spurious annotation of 23 (vol: 2) although the rest of the vertebra is
+ # [1, 2, 3, 4, 5, 6, 7, 8, 9]
+ remove_label(join(labelstr, 'verse212.nii.gz'), 23, 0)
+
+ # spurious annotation of 20 (vol: 4) although the rest of the vertebra is
+ # [1, 2, 3, 4, 5, 6, 7, 8, 9]
+ remove_label(join(labelstr, 'verse214.nii.gz'), 20, 0)
+
+ # spurious annotation of 23 (vol: 15) although the rest of the vertebra is
+ # [1, 2, 3, 4, 5, 6, 7, 8]
+ remove_label(join(labelstr, 'verse223.nii.gz'), 23, 0)
+
+ # spurious annotation of 23 (vol: 1) and 25 (vol: 7) although the rest of the vertebra is
+ # [1, 2, 3, 4, 5, 6, 7, 8, 9]
+ remove_label(join(labelstr, 'verse226.nii.gz'), 23, 0)
+ remove_label(join(labelstr, 'verse226.nii.gz'), 25, 0)
+
+ # spurious annotation of 25 (vol: 27) although the rest of the vertebra is
+ # [1, 2, 3, 4, 5, 6, 7, 8]
+ remove_label(join(labelstr, 'verse227.nii.gz'), 25, 0)
+
+ # spurious annotation of 20 (vol: 24) although the rest of the vertebra is
+ # [1, 2, 3, 4, 5, 6, 7, 8]
+ remove_label(join(labelstr, 'verse232.nii.gz'), 20, 0)
+
+
+ # Now we are ready to run nnU-Net
+
+
+ """# run this part of the code once training is done
+ folder_gt = "/media/fabian/My Book/MedicalDecathlon/nnUNet_raw_splitted/Task056_VerSe/labelsTr"
+
+ folder_pred = "/home/fabian/drives/datasets/results/nnUNet/3d_fullres/Task056_VerSe/nnUNetTrainerV2__nnUNetPlansv2.1/cv_niftis_raw"
+ out_json = "/home/fabian/Task056_VerSe_3d_fullres_summary.json"
+ evaluate_verse_folder(folder_pred, folder_gt, out_json)
+
+ folder_pred = "/home/fabian/drives/datasets/results/nnUNet/3d_lowres/Task056_VerSe/nnUNetTrainerV2__nnUNetPlansv2.1/cv_niftis_raw"
+ out_json = "/home/fabian/Task056_VerSe_3d_lowres_summary.json"
+ evaluate_verse_folder(folder_pred, folder_gt, out_json)
+
+ folder_pred = "/home/fabian/drives/datasets/results/nnUNet/3d_cascade_fullres/Task056_VerSe/nnUNetTrainerV2CascadeFullRes__nnUNetPlansv2.1/cv_niftis_raw"
+ out_json = "/home/fabian/Task056_VerSe_3d_cascade_fullres_summary.json"
+ evaluate_verse_folder(folder_pred, folder_gt, out_json)"""
+
diff --git a/nnunet/dataset_conversion/Task056_Verse_normalize_orientation.py b/nnunet/dataset_conversion/Task056_Verse_normalize_orientation.py
new file mode 100644
index 0000000000000000000000000000000000000000..61988d4a2d0664cfdee7aded2ecf7d8de6ad62e1
--- /dev/null
+++ b/nnunet/dataset_conversion/Task056_Verse_normalize_orientation.py
@@ -0,0 +1,98 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+"""
+This code is copied from https://gist.github.com/nlessmann/24d405eaa82abba6676deb6be839266c. All credits go to the
+original author (user nlessmann on GitHub)
+"""
+
+import numpy as np
+import SimpleITK as sitk
+
+
+def reverse_axes(image):
+ return np.transpose(image, tuple(reversed(range(image.ndim))))
+
+
+def read_image(imagefile):
+ image = sitk.ReadImage(imagefile)
+ data = reverse_axes(sitk.GetArrayFromImage(image)) # switch from zyx to xyz
+ header = {
+ 'spacing': image.GetSpacing(),
+ 'origin': image.GetOrigin(),
+ 'direction': image.GetDirection()
+ }
+ return data, header
+
+
+def save_image(img: np.ndarray, header: dict, output_file: str):
+ """
+ CAREFUL you need to restore_original_slice_orientation before saving!
+ :param img:
+ :param header:
+ :return:
+ """
+ # reverse back
+ img = reverse_axes(img) # switch from zyx to xyz
+ img_itk = sitk.GetImageFromArray(img)
+ img_itk.SetSpacing(header['spacing'])
+ img_itk.SetOrigin(header['origin'])
+ if not isinstance(header['direction'], tuple):
+ img_itk.SetDirection(header['direction'].flatten())
+ else:
+ img_itk.SetDirection(header['direction'])
+
+ sitk.WriteImage(img_itk, output_file)
+
+
+def swap_flip_dimensions(cosine_matrix, image, header=None):
+ # Compute swaps and flips
+ swap = np.argmax(abs(cosine_matrix), axis=0)
+ flip = np.sum(cosine_matrix, axis=0)
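+    # after rounding, each column of the direction cosine matrix has exactly one non-zero entry
+    # (+1 or -1): its row index (argmax of the absolute column) gives the axis permutation and
+    # its sign (the column sum) tells us whether that axis has to be flipped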
+
+ # Apply transformation to image volume
+ image = np.transpose(image, tuple(swap))
+ image = image[tuple(slice(None, None, int(f)) for f in flip)]
+
+ if header is None:
+ return image
+
+ # Apply transformation to header
+ header['spacing'] = tuple(header['spacing'][s] for s in swap)
+ header['direction'] = np.eye(3)
+
+ return image, header
+
+
+def normalize_slice_orientation(image, header):
+ # Preserve original header so that we can easily transform back
+ header['original'] = header.copy()
+
+ # Compute inverse of cosine (round first because we assume 0/1 values only)
+ # to determine how the image has to be transposed and flipped for cosine = identity
+ cosine = np.asarray(header['direction']).reshape(3, 3)
+ cosine_inv = np.linalg.inv(np.round(cosine))
+
+ return swap_flip_dimensions(cosine_inv, image, header)
+
+
+def restore_original_slice_orientation(mask, header):
+    # Use original orientation for transformation because we assume the image to be in
+    # normalized orientation (i.e., identity cosine)
+ cosine = np.asarray(header['original']['direction']).reshape(3, 3)
+ cosine_rnd = np.round(cosine)
+
+ # Apply transformations to both the image and the mask
+ return swap_flip_dimensions(cosine_rnd, mask), header['original']
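+
+
+# hypothetical round trip (a sketch, not part of the original gist):
+#   data, header = read_image("img.nii.gz")
+#   data, header = normalize_slice_orientation(data, header)
+#   # ... run segmentation on data to obtain mask ...
+#   mask, original_header = restore_original_slice_orientation(mask, header)
+#   save_image(mask, original_header, "mask.nii.gz")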
diff --git a/nnunet/dataset_conversion/Task058_ISBI_EM_SEG.py b/nnunet/dataset_conversion/Task058_ISBI_EM_SEG.py
new file mode 100644
index 0000000000000000000000000000000000000000..998dc6dbf604f2227ebfb791ecffa387fe093ce5
--- /dev/null
+++ b/nnunet/dataset_conversion/Task058_ISBI_EM_SEG.py
@@ -0,0 +1,105 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from collections import OrderedDict
+
+import SimpleITK as sitk
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.paths import nnUNet_raw_data
+from skimage import io
+
+
+def export_for_submission(predicted_npz, out_file):
+ """
+ they expect us to submit a 32 bit 3d tif image with values between 0 (100% membrane certainty) and 1
+ (100% non-membrane certainty). We use the softmax output for that
+ :return:
+ """
+ a = np.load(predicted_npz)['softmax']
+ a = a / a.sum(0)[None]
+ # channel 0 is non-membrane prob
+ nonmembr_prob = a[0]
+ assert out_file.endswith(".tif")
+ io.imsave(out_file, nonmembr_prob.astype(np.float32))
+
+
+
+if __name__ == "__main__":
+ # download from here http://brainiac2.mit.edu/isbi_challenge/downloads
+
+ base = "/media/fabian/My Book/datasets/ISBI_EM_SEG"
+ train_volume = io.imread(join(base, "train-volume.tif"))
+ train_labels = io.imread(join(base, "train-labels.tif"))
+ train_labels[train_labels == 255] = 1
+ test_volume = io.imread(join(base, "test-volume.tif"))
+
+ task_id = 58
+ task_name = "ISBI_EM_SEG"
+
+ foldername = "Task%03.0d_%s" % (task_id, task_name)
+
+ out_base = join(nnUNet_raw_data, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ img_tr_itk = sitk.GetImageFromArray(train_volume.astype(np.float32))
+ lab_tr_itk = sitk.GetImageFromArray(1 - train_labels) # walls are foreground, cells background
+ img_te_itk = sitk.GetImageFromArray(test_volume.astype(np.float32))
+
+ img_tr_itk.SetSpacing((4, 4, 50))
+ lab_tr_itk.SetSpacing((4, 4, 50))
+ img_te_itk.SetSpacing((4, 4, 50))
+
+ # 5 copies, otherwise we cannot run nnunet (5 fold cv needs that)
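+    # (each 'fold' then trains and validates on identical data, so cross-validation is effectively
+    # disabled and the five folds simply yield five models trained on the same volume)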
+ sitk.WriteImage(img_tr_itk, join(imagestr, "training0_0000.nii.gz"))
+ sitk.WriteImage(img_tr_itk, join(imagestr, "training1_0000.nii.gz"))
+ sitk.WriteImage(img_tr_itk, join(imagestr, "training2_0000.nii.gz"))
+ sitk.WriteImage(img_tr_itk, join(imagestr, "training3_0000.nii.gz"))
+ sitk.WriteImage(img_tr_itk, join(imagestr, "training4_0000.nii.gz"))
+
+ sitk.WriteImage(lab_tr_itk, join(labelstr, "training0.nii.gz"))
+ sitk.WriteImage(lab_tr_itk, join(labelstr, "training1.nii.gz"))
+ sitk.WriteImage(lab_tr_itk, join(labelstr, "training2.nii.gz"))
+ sitk.WriteImage(lab_tr_itk, join(labelstr, "training3.nii.gz"))
+ sitk.WriteImage(lab_tr_itk, join(labelstr, "training4.nii.gz"))
+
+ sitk.WriteImage(img_te_itk, join(imagests, "testing.nii.gz"))
+
+ json_dict = OrderedDict()
+ json_dict['name'] = task_name
+ json_dict['description'] = task_name
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "see challenge website"
+ json_dict['licence'] = "see challenge website"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "EM",
+ }
+ json_dict['labels'] = {i: str(i) for i in range(2)}
+
+ json_dict['numTraining'] = 5
+ json_dict['numTest'] = 1
+ json_dict['training'] = [{'image': "./imagesTr/training%d.nii.gz" % i, "label": "./labelsTr/training%d.nii.gz" % i} for i in
+ range(5)]
+ json_dict['test'] = ["./imagesTs/testing.nii.gz"]
+
+ save_json(json_dict, os.path.join(out_base, "dataset.json"))
\ No newline at end of file
diff --git a/nnunet/dataset_conversion/Task059_EPFL_EM_MITO_SEG.py b/nnunet/dataset_conversion/Task059_EPFL_EM_MITO_SEG.py
new file mode 100644
index 0000000000000000000000000000000000000000..e70edfd9d6563f6cb4a1b472e5cab109b14d8c9d
--- /dev/null
+++ b/nnunet/dataset_conversion/Task059_EPFL_EM_MITO_SEG.py
@@ -0,0 +1,99 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+from collections import OrderedDict
+from nnunet.paths import nnUNet_raw_data
+from batchgenerators.utilities.file_and_folder_operations import *
+import shutil
+from skimage import io
+import SimpleITK as sitk
+
+
+if __name__ == "__main__":
+ # download from here https://www.epfl.ch/labs/cvlab/data/data-em/
+
+ base = "/media/fabian/My Book/datasets/EPFL_MITO_SEG"
+ train_volume = io.imread(join(base, "training.tif"))
+ train_labels = io.imread(join(base, "training_groundtruth.tif"))
+ train_labels[train_labels == 255] = 1
+ test_volume = io.imread(join(base, "testing.tif"))
+ test_labels = io.imread(join(base, "testing_groundtruth.tif"))
+ test_labels[test_labels == 255] = 1
+
+ task_id = 59
+ task_name = "EPFL_EM_MITO_SEG"
+
+ foldername = "Task%03.0d_%s" % (task_id, task_name)
+
+ out_base = join(nnUNet_raw_data, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ labelste = join(out_base, "labelsTs")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+ maybe_mkdir_p(labelste)
+
+ img_tr_itk = sitk.GetImageFromArray(train_volume.astype(np.float32))
+ lab_tr_itk = sitk.GetImageFromArray(train_labels.astype(np.uint8))
+ img_te_itk = sitk.GetImageFromArray(test_volume.astype(np.float32))
+ lab_te_itk = sitk.GetImageFromArray(test_labels.astype(np.uint8))
+
+ img_tr_itk.SetSpacing((5, 5, 5))
+ lab_tr_itk.SetSpacing((5, 5, 5))
+ img_te_itk.SetSpacing((5, 5, 5))
+ lab_te_itk.SetSpacing((5, 5, 5))
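+    # the EPFL stack is (approximately) isotropic EM data with a ~5 nm voxel size; the
+    # absolute spacing value matters less than images and labels agreeing on it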
+
+ # 5 copies, otherwise we cannot run nnunet (5 fold cv needs that)
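+    # each fold of the CV then simply holds out one of the identical copies, so all five
+    # models effectively train on the full volume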
+ sitk.WriteImage(img_tr_itk, join(imagestr, "training0_0000.nii.gz"))
+ shutil.copy(join(imagestr, "training0_0000.nii.gz"), join(imagestr, "training1_0000.nii.gz"))
+ shutil.copy(join(imagestr, "training0_0000.nii.gz"), join(imagestr, "training2_0000.nii.gz"))
+ shutil.copy(join(imagestr, "training0_0000.nii.gz"), join(imagestr, "training3_0000.nii.gz"))
+ shutil.copy(join(imagestr, "training0_0000.nii.gz"), join(imagestr, "training4_0000.nii.gz"))
+
+ sitk.WriteImage(lab_tr_itk, join(labelstr, "training0.nii.gz"))
+ shutil.copy(join(labelstr, "training0.nii.gz"), join(labelstr, "training1.nii.gz"))
+ shutil.copy(join(labelstr, "training0.nii.gz"), join(labelstr, "training2.nii.gz"))
+ shutil.copy(join(labelstr, "training0.nii.gz"), join(labelstr, "training3.nii.gz"))
+ shutil.copy(join(labelstr, "training0.nii.gz"), join(labelstr, "training4.nii.gz"))
+
+ sitk.WriteImage(img_te_itk, join(imagests, "testing.nii.gz"))
+ sitk.WriteImage(lab_te_itk, join(labelste, "testing.nii.gz"))
+
+ json_dict = OrderedDict()
+ json_dict['name'] = task_name
+ json_dict['description'] = task_name
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "see challenge website"
+ json_dict['licence'] = "see challenge website"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "EM",
+ }
+ json_dict['labels'] = {i: str(i) for i in range(2)}
+
+ json_dict['numTraining'] = 5
+ json_dict['numTest'] = 1
+ json_dict['training'] = [{'image': "./imagesTr/training%d.nii.gz" % i, "label": "./labelsTr/training%d.nii.gz" % i} for i in
+ range(5)]
+ json_dict['test'] = ["./imagesTs/testing.nii.gz"]
+
+ save_json(json_dict, os.path.join(out_base, "dataset.json"))
\ No newline at end of file
diff --git a/nnunet/dataset_conversion/Task061_CREMI.py b/nnunet/dataset_conversion/Task061_CREMI.py
new file mode 100644
index 0000000000000000000000000000000000000000..916396441972c8f71b13b77cda98c9cf887f3c36
--- /dev/null
+++ b/nnunet/dataset_conversion/Task061_CREMI.py
@@ -0,0 +1,146 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from collections import OrderedDict
+
+from batchgenerators.utilities.file_and_folder_operations import *
+import numpy as np
+from nnunet.paths import nnUNet_raw_data, preprocessing_output_dir
+import SimpleITK as sitk
+
+try:
+ import h5py
+except ImportError:
+ h5py = None
+
+
+def load_sample(filename):
+ # we need raw data and seg
+ f = h5py.File(filename, 'r')
+ data = np.array(f['volumes']['raw'])
+
+ if 'labels' in f['volumes'].keys():
+ labels = np.array(f['volumes']['labels']['clefts'])
+ # clefts are low values, background is high
+ labels = (labels < 100000).astype(np.uint8)
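+        # background is stored as the maximum uint64 value (0xffffffffffffffff, see
+        # prepare_submission below), so any value below this generous cutoff is a cleft id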
+ else:
+ labels = None
+ return data, labels
+
+
+def save_as_nifti(arr, filename, spacing):
+ itk_img = sitk.GetImageFromArray(arr)
+ itk_img.SetSpacing(spacing)
+ sitk.WriteImage(itk_img, filename)
+
+
+def prepare_submission():
+ from cremi.io import CremiFile
+ from cremi.Volume import Volume
+
+ base = "/home/fabian/drives/datasets/results/nnUNet/test_sets/Task061_CREMI/"
+    for sample_id in ['A+', 'B+', 'C+']:
+        pred = sitk.GetArrayFromImage(sitk.ReadImage(
+            join(base, 'results_3d_fullres', "sample_%s.nii.gz" % sample_id.lower()))).astype(np.uint64)
+        # CREMI expects background to be encoded as the maximum uint64 value
+        pred[pred == 0] = 0xffffffffffffffff
+        out_file = CremiFile(join(base, 'sample_%s_20160601.hdf' % sample_id), 'w')
+        clefts = Volume(pred, (40., 4., 4.))
+        out_file.write_clefts(clefts)
+        out_file.close()
+
+
+if __name__ == "__main__":
+ assert h5py is not None, "you need h5py for this. Install with 'pip install h5py'"
+
+ foldername = "Task061_CREMI"
+ out_base = join(nnUNet_raw_data, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ base = "/media/fabian/My Book/datasets/CREMI"
+
+ # train
+ img, label = load_sample(join(base, "sample_A_20160501.hdf"))
+ save_as_nifti(img, join(imagestr, "sample_a_0000.nii.gz"), (4, 4, 40))
+ save_as_nifti(label, join(labelstr, "sample_a.nii.gz"), (4, 4, 40))
+ img, label = load_sample(join(base, "sample_B_20160501.hdf"))
+ save_as_nifti(img, join(imagestr, "sample_b_0000.nii.gz"), (4, 4, 40))
+ save_as_nifti(label, join(labelstr, "sample_b.nii.gz"), (4, 4, 40))
+ img, label = load_sample(join(base, "sample_C_20160501.hdf"))
+ save_as_nifti(img, join(imagestr, "sample_c_0000.nii.gz"), (4, 4, 40))
+ save_as_nifti(label, join(labelstr, "sample_c.nii.gz"), (4, 4, 40))
+
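+    # sample_c is written out twice more (as sample_d and sample_e), apparently so that the
+    # dataset has the 5 training cases nnU-Net's default 5-fold CV expects; note that the
+    # manual splits below only ever use samples a, b and c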
+ save_as_nifti(img, join(imagestr, "sample_d_0000.nii.gz"), (4, 4, 40))
+ save_as_nifti(label, join(labelstr, "sample_d.nii.gz"), (4, 4, 40))
+
+ save_as_nifti(img, join(imagestr, "sample_e_0000.nii.gz"), (4, 4, 40))
+ save_as_nifti(label, join(labelstr, "sample_e.nii.gz"), (4, 4, 40))
+
+ # test
+ img, label = load_sample(join(base, "sample_A+_20160601.hdf"))
+ save_as_nifti(img, join(imagests, "sample_a+_0000.nii.gz"), (4, 4, 40))
+ img, label = load_sample(join(base, "sample_B+_20160601.hdf"))
+ save_as_nifti(img, join(imagests, "sample_b+_0000.nii.gz"), (4, 4, 40))
+ img, label = load_sample(join(base, "sample_C+_20160601.hdf"))
+ save_as_nifti(img, join(imagests, "sample_c+_0000.nii.gz"), (4, 4, 40))
+
+ json_dict = OrderedDict()
+ json_dict['name'] = foldername
+ json_dict['description'] = foldername
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "see challenge website"
+ json_dict['licence'] = "see challenge website"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "EM",
+ }
+ json_dict['labels'] = {i: str(i) for i in range(2)}
+
+ json_dict['numTraining'] = 5
+    json_dict['numTest'] = 3
+ json_dict['training'] = [{'image': "./imagesTr/sample_%s.nii.gz" % i, "label": "./labelsTr/sample_%s.nii.gz" % i} for i in
+ ['a', 'b', 'c', 'd', 'e']]
+
+ json_dict['test'] = ["./imagesTs/sample_a+.nii.gz", "./imagesTs/sample_b+.nii.gz", "./imagesTs/sample_c+.nii.gz"]
+
+ save_json(json_dict, os.path.join(out_base, "dataset.json"))
+
+ out_preprocessed = join(preprocessing_output_dir, foldername)
+ maybe_mkdir_p(out_preprocessed)
+    # manual splits: train and val are identical on purpose - we train 5 models, each on
+    # all three samples, and use the training data itself for (pseudo-)validation
+ splits = [{'train': ["sample_a", "sample_b", "sample_c"], 'val': ["sample_a", "sample_b", "sample_c"]},
+ {'train': ["sample_a", "sample_b", "sample_c"], 'val': ["sample_a", "sample_b", "sample_c"]},
+ {'train': ["sample_a", "sample_b", "sample_c"], 'val': ["sample_a", "sample_b", "sample_c"]},
+ {'train': ["sample_a", "sample_b", "sample_c"], 'val': ["sample_a", "sample_b", "sample_c"]},
+ {'train': ["sample_a", "sample_b", "sample_c"], 'val': ["sample_a", "sample_b", "sample_c"]}]
+ save_pickle(splits, join(out_preprocessed, "splits_final.pkl"))
\ No newline at end of file
diff --git a/nnunet/dataset_conversion/Task062_NIHPancreas.py b/nnunet/dataset_conversion/Task062_NIHPancreas.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e29e6e052c0c0ff1ba9a537620ea62c9c9b9dad
--- /dev/null
+++ b/nnunet/dataset_conversion/Task062_NIHPancreas.py
@@ -0,0 +1,89 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from collections import OrderedDict
+from nnunet.paths import nnUNet_raw_data
+from batchgenerators.utilities.file_and_folder_operations import *
+import shutil
+from multiprocessing import Pool
+import nibabel
+
+
+def reorient(filename):
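+    # overwrite the file in place with its closest-canonical (RAS+) reorientation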
+ img = nibabel.load(filename)
+ img = nibabel.as_closest_canonical(img)
+ nibabel.save(img, filename)
+
+
+if __name__ == "__main__":
+ base = "/media/fabian/DeepLearningData/Pancreas-CT"
+
+ # reorient
+ p = Pool(8)
+ results = []
+
+ for f in subfiles(join(base, "data"), suffix=".nii.gz"):
+ results.append(p.map_async(reorient, (f, )))
+ _ = [i.get() for i in results]
+
+ for f in subfiles(join(base, "TCIA_pancreas_labels-02-05-2017"), suffix=".nii.gz"):
+ results.append(p.map_async(reorient, (f, )))
+ _ = [i.get() for i in results]
+
+ task_id = 62
+ task_name = "NIHPancreas"
+
+ foldername = "Task%03.0d_%s" % (task_id, task_name)
+
+ out_base = join(nnUNet_raw_data, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ train_patient_names = []
+ test_patient_names = []
+ cases = list(range(1, 83))
+ folder_data = join(base, "data")
+ folder_labels = join(base, "TCIA_pancreas_labels-02-05-2017")
+ for c in cases:
+ casename = "pancreas_%04.0d" % c
+ shutil.copy(join(folder_data, "PANCREAS_%04.0d.nii.gz" % c), join(imagestr, casename + "_0000.nii.gz"))
+ shutil.copy(join(folder_labels, "label%04.0d.nii.gz" % c), join(labelstr, casename + ".nii.gz"))
+ train_patient_names.append(casename)
+
+ json_dict = OrderedDict()
+ json_dict['name'] = task_name
+ json_dict['description'] = task_name
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "see website"
+ json_dict['licence'] = "see website"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "CT",
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "Pancreas",
+ }
+ json_dict['numTraining'] = len(train_patient_names)
+ json_dict['numTest'] = len(test_patient_names)
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i.split("/")[-1], "label": "./labelsTr/%s.nii.gz" % i.split("/")[-1]} for i in
+ train_patient_names]
+ json_dict['test'] = ["./imagesTs/%s.nii.gz" % i.split("/")[-1] for i in test_patient_names]
+
+ save_json(json_dict, os.path.join(out_base, "dataset.json"))
diff --git a/nnunet/dataset_conversion/Task064_KiTS_labelsFixed.py b/nnunet/dataset_conversion/Task064_KiTS_labelsFixed.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa0ee30d69dfd3236273ee7bb8eae873ec63a309
--- /dev/null
+++ b/nnunet/dataset_conversion/Task064_KiTS_labelsFixed.py
@@ -0,0 +1,84 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import shutil
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.paths import nnUNet_raw_data
+
+
+if __name__ == "__main__":
+ """
+ This is the KiTS dataset after Nick fixed all the labels that had errors. Downloaded on Jan 6th 2020
+ """
+
+ base = "/media/fabian/My Book/datasets/KiTS_clean/kits19/data"
+
+ task_id = 64
+ task_name = "KiTS_labelsFixed"
+
+ foldername = "Task%03.0d_%s" % (task_id, task_name)
+
+ out_base = join(nnUNet_raw_data, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ train_patient_names = []
+ test_patient_names = []
+ all_cases = subfolders(base, join=False)
+
+ train_patients = all_cases[:210]
+ test_patients = all_cases[210:]
+
+ for p in train_patients:
+ curr = join(base, p)
+ label_file = join(curr, "segmentation.nii.gz")
+ image_file = join(curr, "imaging.nii.gz")
+ shutil.copy(image_file, join(imagestr, p + "_0000.nii.gz"))
+ shutil.copy(label_file, join(labelstr, p + ".nii.gz"))
+ train_patient_names.append(p)
+
+ for p in test_patients:
+ curr = join(base, p)
+ image_file = join(curr, "imaging.nii.gz")
+ shutil.copy(image_file, join(imagests, p + "_0000.nii.gz"))
+ test_patient_names.append(p)
+
+ json_dict = {}
+ json_dict['name'] = "KiTS"
+ json_dict['description'] = "kidney and kidney tumor segmentation"
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "KiTS data for nnunet"
+ json_dict['licence'] = ""
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "CT",
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "Kidney",
+ "2": "Tumor"
+ }
+
+ json_dict['numTraining'] = len(train_patient_names)
+ json_dict['numTest'] = len(test_patient_names)
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i.split("/")[-1], "label": "./labelsTr/%s.nii.gz" % i.split("/")[-1]} for i in
+ train_patient_names]
+ json_dict['test'] = ["./imagesTs/%s.nii.gz" % i.split("/")[-1] for i in test_patient_names]
+
+ save_json(json_dict, os.path.join(out_base, "dataset.json"))
diff --git a/nnunet/dataset_conversion/Task065_KiTS_NicksLabels.py b/nnunet/dataset_conversion/Task065_KiTS_NicksLabels.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf5cb147006d00b6ae5c4778ab74c86f3b775c90
--- /dev/null
+++ b/nnunet/dataset_conversion/Task065_KiTS_NicksLabels.py
@@ -0,0 +1,87 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import shutil
+
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.paths import nnUNet_raw_data
+
+if __name__ == "__main__":
+ """
+ Nick asked me to rerun the training with other labels (the Kidney region is defined differently).
+
+    These labels are defined in the interpolated spacing. I don't like that, but that's how it is.
+ """
+
+ base = "/media/fabian/My Book/datasets/KiTS_NicksLabels/kits19/data"
+ labelsdir = "/media/fabian/My Book/datasets/KiTS_NicksLabels/filled_labels"
+
+ task_id = 65
+ task_name = "KiTS_NicksLabels"
+
+ foldername = "Task%03.0d_%s" % (task_id, task_name)
+
+ out_base = join(nnUNet_raw_data, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ train_patient_names = []
+ test_patient_names = []
+ all_cases = subfolders(base, join=False)
+
+ train_patients = all_cases[:210]
+ test_patients = all_cases[210:]
+
+ for p in train_patients:
+ curr = join(base, p)
+ label_file = join(labelsdir, p + ".nii.gz")
+ image_file = join(curr, "imaging.nii.gz")
+ shutil.copy(image_file, join(imagestr, p + "_0000.nii.gz"))
+ shutil.copy(label_file, join(labelstr, p + ".nii.gz"))
+ train_patient_names.append(p)
+
+ for p in test_patients:
+ curr = join(base, p)
+ image_file = join(curr, "imaging.nii.gz")
+ shutil.copy(image_file, join(imagests, p + "_0000.nii.gz"))
+ test_patient_names.append(p)
+
+ json_dict = {}
+ json_dict['name'] = "KiTS"
+ json_dict['description'] = "kidney and kidney tumor segmentation"
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "KiTS data for nnunet"
+ json_dict['licence'] = ""
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "CT",
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "Kidney",
+ "2": "Tumor"
+ }
+
+ json_dict['numTraining'] = len(train_patient_names)
+ json_dict['numTest'] = len(test_patient_names)
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i.split("/")[-1], "label": "./labelsTr/%s.nii.gz" % i.split("/")[-1]} for i in
+ train_patient_names]
+ json_dict['test'] = ["./imagesTs/%s.nii.gz" % i.split("/")[-1] for i in test_patient_names]
+
+ save_json(json_dict, os.path.join(out_base, "dataset.json"))
diff --git a/nnunet/dataset_conversion/Task069_CovidSeg.py b/nnunet/dataset_conversion/Task069_CovidSeg.py
new file mode 100644
index 0000000000000000000000000000000000000000..73e26f9764984423ae4634d6f394bb677489c63d
--- /dev/null
+++ b/nnunet/dataset_conversion/Task069_CovidSeg.py
@@ -0,0 +1,68 @@
+import shutil
+
+from batchgenerators.utilities.file_and_folder_operations import *
+import SimpleITK as sitk
+from nnunet.paths import nnUNet_raw_data
+
+if __name__ == '__main__':
+    # data is available at http://medicalsegmentation.com/covid19/
+ download_dir = '/home/fabian/Downloads'
+
+ task_id = 69
+ task_name = "CovidSeg"
+
+ foldername = "Task%03.0d_%s" % (task_id, task_name)
+
+ out_base = join(nnUNet_raw_data, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ train_patient_names = []
+ test_patient_names = []
+
+ # the niftis are 3d, but they are just stacks of 2d slices from different patients. So no 3d U-Net, please
+
+ # the training stack has 100 slices, so we split it into 5 equally sized parts (20 slices each) for cross-validation
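+    # slicing with f::5 interleaves the parts: part_0 gets slices 0, 5, ..., 95, part_1 gets
+    # slices 1, 6, ..., 96, etc., so each part samples the entire stack (and thus all patients)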
+ training_data = sitk.GetArrayFromImage(sitk.ReadImage(join(download_dir, 'tr_im.nii.gz')))
+ training_labels = sitk.GetArrayFromImage(sitk.ReadImage(join(download_dir, 'tr_mask.nii.gz')))
+
+ for f in range(5):
+ this_name = 'part_%d' % f
+ data = training_data[f::5]
+ labels = training_labels[f::5]
+ sitk.WriteImage(sitk.GetImageFromArray(data), join(imagestr, this_name + '_0000.nii.gz'))
+ sitk.WriteImage(sitk.GetImageFromArray(labels), join(labelstr, this_name + '.nii.gz'))
+ train_patient_names.append(this_name)
+
+ shutil.copy(join(download_dir, 'val_im.nii.gz'), join(imagests, 'val_im.nii.gz'))
+
+ test_patient_names.append('val_im')
+
+ json_dict = {}
+ json_dict['name'] = task_name
+ json_dict['description'] = ""
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = ""
+ json_dict['licence'] = ""
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "nonct",
+ }
+    json_dict['labels'] = {
+        "0": "background",
+        # class names as described on medicalsegmentation.com/covid19
+        "1": "ground-glass opacity",
+        "2": "consolidation",
+        "3": "pleural effusion",
+    }
+
+ json_dict['numTraining'] = len(train_patient_names)
+ json_dict['numTest'] = len(test_patient_names)
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i.split("/")[-1], "label": "./labelsTr/%s.nii.gz" % i.split("/")[-1]} for i in
+ train_patient_names]
+ json_dict['test'] = ["./imagesTs/%s.nii.gz" % i.split("/")[-1] for i in test_patient_names]
+
+ save_json(json_dict, os.path.join(out_base, "dataset.json"))
diff --git a/nnunet/dataset_conversion/Task075_Fluo_C3DH_A549_ManAndSim.py b/nnunet/dataset_conversion/Task075_Fluo_C3DH_A549_ManAndSim.py
new file mode 100644
index 0000000000000000000000000000000000000000..2eee2af5c1f9b8036f8c61323c722f12330d8dcf
--- /dev/null
+++ b/nnunet/dataset_conversion/Task075_Fluo_C3DH_A549_ManAndSim.py
@@ -0,0 +1,137 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from multiprocessing import Pool
+import SimpleITK as sitk
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.paths import nnUNet_raw_data
+from nnunet.paths import preprocessing_output_dir
+from skimage.io import imread
+
+
+def load_tiff_convert_to_nifti(img_file, lab_file, img_out_base, anno_out, spacing):
+ img = imread(img_file)
+ img_itk = sitk.GetImageFromArray(img.astype(np.float32))
+ img_itk.SetSpacing(np.array(spacing)[::-1])
+ sitk.WriteImage(img_itk, join(img_out_base + "_0000.nii.gz"))
+
+ if lab_file is not None:
+ l = imread(lab_file)
+ l[l > 0] = 1
+ l_itk = sitk.GetImageFromArray(l.astype(np.uint8))
+ l_itk.SetSpacing(np.array(spacing)[::-1])
+ sitk.WriteImage(l_itk, anno_out)
+
+
+def prepare_task(base, task_id, task_name, spacing):
+ p = Pool(16)
+
+ foldername = "Task%03.0d_%s" % (task_id, task_name)
+
+ out_base = join(nnUNet_raw_data, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ train_patient_names = []
+ test_patient_names = []
+ res = []
+
+ for train_sequence in [i for i in subfolders(base + "_train", join=False) if not i.endswith("_GT")]:
+ train_cases = subfiles(join(base + '_train', train_sequence), suffix=".tif", join=False)
+ for t in train_cases:
+ casename = train_sequence + "_" + t[:-4]
+ img_file = join(base + '_train', train_sequence, t)
+ lab_file = join(base + '_train', train_sequence + "_GT", "SEG", "man_seg" + t[1:])
+ if not isfile(lab_file):
+ continue
+ img_out_base = join(imagestr, casename)
+ anno_out = join(labelstr, casename + ".nii.gz")
+ res.append(
+ p.starmap_async(load_tiff_convert_to_nifti, ((img_file, lab_file, img_out_base, anno_out, spacing),)))
+ train_patient_names.append(casename)
+
+ for test_sequence in [i for i in subfolders(base + "_test", join=False) if not i.endswith("_GT")]:
+ test_cases = subfiles(join(base + '_test', test_sequence), suffix=".tif", join=False)
+ for t in test_cases:
+ casename = test_sequence + "_" + t[:-4]
+ img_file = join(base + '_test', test_sequence, t)
+ lab_file = None
+ img_out_base = join(imagests, casename)
+ anno_out = None
+ res.append(
+ p.starmap_async(load_tiff_convert_to_nifti, ((img_file, lab_file, img_out_base, anno_out, spacing),)))
+ test_patient_names.append(casename)
+
+ _ = [i.get() for i in res]
+
+ json_dict = {}
+ json_dict['name'] = task_name
+ json_dict['description'] = ""
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = ""
+ json_dict['licence'] = ""
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "BF",
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "cell",
+ }
+
+ json_dict['numTraining'] = len(train_patient_names)
+ json_dict['numTest'] = len(test_patient_names)
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i, "label": "./labelsTr/%s.nii.gz" % i} for i in
+ train_patient_names]
+ json_dict['test'] = ["./imagesTs/%s.nii.gz" % i for i in test_patient_names]
+
+ save_json(json_dict, os.path.join(out_base, "dataset.json"))
+ p.close()
+ p.join()
+
+
+if __name__ == "__main__":
+ base = "/media/fabian/My Book/datasets/CellTrackingChallenge/Fluo-C3DH-A549_ManAndSim"
+ task_id = 75
+ task_name = 'Fluo_C3DH_A549_ManAndSim'
+ spacing = (1, 0.126, 0.126)
+ prepare_task(base, task_id, task_name, spacing)
+
+ task_name = "Task075_Fluo_C3DH_A549_ManAndSim"
+ labelsTr = join(nnUNet_raw_data, task_name, "labelsTr")
+ cases = subfiles(labelsTr, suffix='.nii.gz', join=False)
+ splits = []
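+    # four manual folds: folds 0/1 each hold out one real (manually annotated) sequence for
+    # validation and train on everything else; folds 2/3 hold out one simulated (SIM)
+    # sequence instead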
+ splits.append(
+ {'train': [i[:-7] for i in cases if i.startswith('01_') or i.startswith('02_SIM')],
+ 'val': [i[:-7] for i in cases if i.startswith('02_') and not i.startswith('02_SIM')]}
+ )
+ splits.append(
+ {'train': [i[:-7] for i in cases if i.startswith('02_') or i.startswith('01_SIM')],
+ 'val': [i[:-7] for i in cases if i.startswith('01_') and not i.startswith('01_SIM')]}
+ )
+ splits.append(
+        {'train': [i[:-7] for i in cases if i.startswith('01_') or (i.startswith('02_') and not i.startswith('02_SIM'))],
+ 'val': [i[:-7] for i in cases if i.startswith('02_SIM')]}
+ )
+ splits.append(
+        {'train': [i[:-7] for i in cases if i.startswith('02_') or (i.startswith('01_') and not i.startswith('01_SIM'))],
+ 'val': [i[:-7] for i in cases if i.startswith('01_SIM')]}
+ )
+ save_pickle(splits, join(preprocessing_output_dir, task_name, "splits_final.pkl"))
+
diff --git a/nnunet/dataset_conversion/Task076_Fluo_N3DH_SIM.py b/nnunet/dataset_conversion/Task076_Fluo_N3DH_SIM.py
new file mode 100644
index 0000000000000000000000000000000000000000..435592c5d6f10f3c15fcbe16a140aa06eecf5a00
--- /dev/null
+++ b/nnunet/dataset_conversion/Task076_Fluo_N3DH_SIM.py
@@ -0,0 +1,312 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from multiprocessing.dummy import Pool
+
+import SimpleITK as sitk
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from skimage.io import imread
+from skimage.io import imsave
+from skimage.morphology import ball
+from skimage.morphology import erosion
+from skimage.transform import resize
+
+from nnunet.paths import nnUNet_raw_data
+from nnunet.paths import preprocessing_output_dir
+
+
+def load_bmp_convert_to_nifti_borders(img_file, lab_file, img_out_base, anno_out, spacing, border_thickness=0.7):
+ img = imread(img_file)
+ img_itk = sitk.GetImageFromArray(img.astype(np.float32))
+ img_itk.SetSpacing(np.array(spacing)[::-1])
+ sitk.WriteImage(img_itk, join(img_out_base + "_0000.nii.gz"))
+
+ if lab_file is not None:
+ l = imread(lab_file)
+ borders = generate_border_as_suggested_by_twollmann(l, spacing, border_thickness)
+ l[l > 0] = 1
+ l[borders == 1] = 2
+ l_itk = sitk.GetImageFromArray(l.astype(np.uint8))
+ l_itk.SetSpacing(np.array(spacing)[::-1])
+ sitk.WriteImage(l_itk, anno_out)
+
+
+def generate_ball(spacing, radius, dtype=int):
+ radius_in_voxels = np.round(radius / np.array(spacing)).astype(int)
+ n = 2 * radius_in_voxels + 1
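+    # e.g. spacing=(0.2, 0.125, 0.125) and radius=0.5 give radius_in_voxels=[2, 4, 4]
+    # (np.round rounds the halfway 2.5 to the even 2) and a structuring element of shape (5, 9, 9)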
+ ball_iso = ball(max(n) * 2, dtype=np.float64)
+ ball_resampled = resize(ball_iso, n, 1, 'constant', 0, clip=True, anti_aliasing=False, preserve_range=True)
+ ball_resampled[ball_resampled > 0.5] = 1
+ ball_resampled[ball_resampled <= 0.5] = 0
+ return ball_resampled.astype(dtype)
+
+
+def generate_border_as_suggested_by_twollmann(label_img: np.ndarray, spacing, border_thickness: float = 2) -> np.ndarray:
+ border = np.zeros_like(label_img)
+ selem = generate_ball(spacing, border_thickness)
+ for l in np.unique(label_img):
+ if l == 0: continue
+ mask = (label_img == l).astype(int)
+ eroded = erosion(mask, selem)
+ border[(eroded == 0) & (mask != 0)] = 1
+ return border
+
+
+def find_differences(labelstr1, labelstr2):
+ for n in subfiles(labelstr1, suffix='.nii.gz', join=False):
+ a = sitk.GetArrayFromImage(sitk.ReadImage(join(labelstr1, n)))
+ b = sitk.GetArrayFromImage(sitk.ReadImage(join(labelstr2, n)))
+ print(n, np.sum(a != b))
+
+
+def prepare_task(base, task_id, task_name, spacing, border_thickness: float = 15, processes: int = 16):
+ p = Pool(processes)
+
+ foldername = "Task%03.0d_%s" % (task_id, task_name)
+
+ out_base = join(nnUNet_raw_data, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ train_patient_names = []
+ test_patient_names = []
+ res = []
+
+ for train_sequence in [i for i in subfolders(base + "_train", join=False) if not i.endswith("_GT")]:
+ train_cases = subfiles(join(base + '_train', train_sequence), suffix=".tif", join=False)
+ for t in train_cases:
+ casename = train_sequence + "_" + t[:-4]
+ img_file = join(base + '_train', train_sequence, t)
+ lab_file = join(base + '_train', train_sequence + "_GT", "SEG", "man_seg" + t[1:])
+ if not isfile(lab_file):
+ continue
+ img_out_base = join(imagestr, casename)
+ anno_out = join(labelstr, casename + ".nii.gz")
+ res.append(
+ p.starmap_async(load_bmp_convert_to_nifti_borders, ((img_file, lab_file, img_out_base, anno_out, spacing, border_thickness),)))
+ train_patient_names.append(casename)
+
+ for test_sequence in [i for i in subfolders(base + "_test", join=False) if not i.endswith("_GT")]:
+ test_cases = subfiles(join(base + '_test', test_sequence), suffix=".tif", join=False)
+ for t in test_cases:
+ casename = test_sequence + "_" + t[:-4]
+ img_file = join(base + '_test', test_sequence, t)
+ lab_file = None
+ img_out_base = join(imagests, casename)
+ anno_out = None
+ res.append(
+ p.starmap_async(load_bmp_convert_to_nifti_borders, ((img_file, lab_file, img_out_base, anno_out, spacing, border_thickness),)))
+ test_patient_names.append(casename)
+
+ _ = [i.get() for i in res]
+
+ json_dict = {}
+ json_dict['name'] = task_name
+ json_dict['description'] = ""
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = ""
+ json_dict['licence'] = ""
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "BF",
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "cell",
+ "2": "border",
+ }
+
+ json_dict['numTraining'] = len(train_patient_names)
+ json_dict['numTest'] = len(test_patient_names)
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i, "label": "./labelsTr/%s.nii.gz" % i} for i in
+ train_patient_names]
+ json_dict['test'] = ["./imagesTs/%s.nii.gz" % i for i in test_patient_names]
+
+ save_json(json_dict, os.path.join(out_base, "dataset.json"))
+ p.close()
+ p.join()
+
+
+def plot_images(folder, output_folder):
+ maybe_mkdir_p(output_folder)
+ import matplotlib.pyplot as plt
+ for i in subfiles(folder, suffix='.nii.gz', join=False):
+ img = sitk.GetArrayFromImage(sitk.ReadImage(join(folder, i)))
+ center_slice = img[img.shape[0]//2]
+ plt.imsave(join(output_folder, i[:-7] + '.png'), center_slice)
+
+
+def convert_to_tiff(nifti_image: str, output_name: str):
+ npy = sitk.GetArrayFromImage(sitk.ReadImage(nifti_image))
+ imsave(output_name, npy.astype(np.uint16), compress=6)
+
+
+def convert_to_instance_seg(arr: np.ndarray, spacing: tuple = (0.2, 0.125, 0.125)):
+ from skimage.morphology import label, dilation
+ # 1 is core, 2 is border
+ objects = label((arr == 1).astype(int))
+ final = np.copy(objects)
+ remaining_border = arr == 2
+ current = np.copy(objects)
+ dilated_mm = np.array((0, 0, 0))
+ spacing = np.array(spacing)
+
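+    # grow the labeled cores into the border voxels with an anisotropy-aware dilation: axes
+    # with the finest spacing dilate in every iteration, coarser axes only join in once their
+    # accumulated physical dilation would otherwise fall behind, approximating isotropic
+    # growth in physical units despite anisotropic voxels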
+ while np.sum(remaining_border) > 0:
+ strel_size = [0, 0, 0]
+ maximum_dilation = max(dilated_mm)
+ for i in range(3):
+ if spacing[i] == min(spacing):
+ strel_size[i] = 1
+ continue
+ if dilated_mm[i] + spacing[i] / 2 < maximum_dilation:
+ strel_size[i] = 1
+ ball_here = ball(1)
+
+ if strel_size[0] == 0: ball_here = ball_here[1:2]
+ if strel_size[1] == 0: ball_here = ball_here[:, 1:2]
+ if strel_size[2] == 0: ball_here = ball_here[:, :, 1:2]
+
+ dilated = dilation(current, ball_here)
+ diff = (current == 0) & (dilated != current)
+ final[diff & remaining_border] = dilated[diff & remaining_border]
+ remaining_border[diff] = 0
+ current = dilated
+ dilated_mm = [dilated_mm[i] + spacing[i] if strel_size[i] == 1 else dilated_mm[i] for i in range(3)]
+ return final.astype(np.uint32)
+
+
+def convert_to_instance_seg2(arr: np.ndarray, spacing: tuple = (0.2, 0.125, 0.125), small_center_threshold=30,
+ isolated_border_as_separate_instance_threshold: int = 15):
+ from skimage.morphology import label, dilation
+ # we first identify centers that are too small and set them to be border. This should remove false positive instances
+ objects = label((arr == 1).astype(int))
+ for o in np.unique(objects):
+ if o > 0 and np.sum(objects == o) <= small_center_threshold:
+ arr[objects == o] = 2
+
+ # 1 is core, 2 is border
+ objects = label((arr == 1).astype(int))
+ final = np.copy(objects)
+ remaining_border = arr == 2
+ current = np.copy(objects)
+ dilated_mm = np.array((0, 0, 0))
+ spacing = np.array(spacing)
+
+ while np.sum(remaining_border) > 0:
+ strel_size = [0, 0, 0]
+ maximum_dilation = max(dilated_mm)
+ for i in range(3):
+ if spacing[i] == min(spacing):
+ strel_size[i] = 1
+ continue
+ if dilated_mm[i] + spacing[i] / 2 < maximum_dilation:
+ strel_size[i] = 1
+ ball_here = ball(1)
+
+ if strel_size[0] == 0: ball_here = ball_here[1:2]
+ if strel_size[1] == 0: ball_here = ball_here[:, 1:2]
+ if strel_size[2] == 0: ball_here = ball_here[:, :, 1:2]
+
+ dilated = dilation(current, ball_here)
+ diff = (current == 0) & (dilated != current)
+ final[diff & remaining_border] = dilated[diff & remaining_border]
+ remaining_border[diff] = 0
+ current = dilated
+ dilated_mm = [dilated_mm[i] + spacing[i] if strel_size[i] == 1 else dilated_mm[i] for i in range(3)]
+
+ # what can happen is that a cell is so small that the network only predicted border and no core. This cell will be
+ # fused with the nearest other instance, which we don't want. Therefore we identify isolated border predictions and
+ # give them a separate instance id
+ # we identify isolated border predictions by checking each foreground object in arr and see whether this object
+ # also contains label 1
+ max_label = np.max(final)
+
+ foreground_objects = label((arr != 0).astype(int))
+ for i in np.unique(foreground_objects):
+ if i > 0 and (1 not in np.unique(arr[foreground_objects==i])):
+ size_of_object = np.sum(foreground_objects==i)
+ if size_of_object >= isolated_border_as_separate_instance_threshold:
+ final[foreground_objects == i] = max_label + 1
+ max_label += 1
+
+ return final.astype(np.uint32)
+
+
+def load_instanceseg_save(in_file: str, out_file:str, better: bool):
+ itk_img = sitk.ReadImage(in_file)
+ if not better:
+ instanceseg = convert_to_instance_seg(sitk.GetArrayFromImage(itk_img))
+ else:
+ instanceseg = convert_to_instance_seg2(sitk.GetArrayFromImage(itk_img))
+ itk_out = sitk.GetImageFromArray(instanceseg)
+ itk_out.CopyInformation(itk_img)
+ sitk.WriteImage(itk_out, out_file)
+
+
+def convert_all_to_instance(input_folder: str, output_folder: str, processes: int = 24, better: bool = False):
+ maybe_mkdir_p(output_folder)
+ p = Pool(processes)
+ files = subfiles(input_folder, suffix='.nii.gz', join=False)
+ output_files = [join(output_folder, i) for i in files]
+ input_files = [join(input_folder, i) for i in files]
+ better = [better] * len(files)
+ r = p.starmap_async(load_instanceseg_save, zip(input_files, output_files, better))
+ _ = r.get()
+ p.close()
+ p.join()
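+# example usage (hypothetical paths): convert semantic core/border predictions to instances
+# convert_all_to_instance('/path/to/pred_semantic', '/path/to/pred_instances', better=True)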
+
+
+if __name__ == "__main__":
+ base = "/home/fabian/data/Fluo-N3DH-SIM"
+ task_id = 76
+ task_name = 'Fluo_N3DH_SIM'
+ spacing = (0.2, 0.125, 0.125)
+ border_thickness = 0.5
+
+ prepare_task(base, task_id, task_name, spacing, border_thickness, 12)
+
+ # we need custom splits
+ task_name = "Task076_Fluo_N3DH_SIM"
+ labelsTr = join(nnUNet_raw_data, task_name, "labelsTr")
+ cases = subfiles(labelsTr, suffix='.nii.gz', join=False)
+ splits = []
+ splits.append(
+ {'train': [i[:-7] for i in cases if i.startswith('01_')],
+ 'val': [i[:-7] for i in cases if i.startswith('02_')]}
+ )
+ splits.append(
+ {'train': [i[:-7] for i in cases if i.startswith('02_')],
+ 'val': [i[:-7] for i in cases if i.startswith('01_')]}
+ )
+
+ maybe_mkdir_p(join(preprocessing_output_dir, task_name))
+
+ save_pickle(splits, join(preprocessing_output_dir, task_name, "splits_final.pkl"))
+
+ # test set was converted to instance seg with convert_all_to_instance with better=True
+
+ # convert to tiff with convert_to_tiff
diff --git a/nnunet/dataset_conversion/Task082_BraTS_2020.py b/nnunet/dataset_conversion/Task082_BraTS_2020.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ba20f93f496287fb063c11909b23c1c393a2a90
--- /dev/null
+++ b/nnunet/dataset_conversion/Task082_BraTS_2020.py
@@ -0,0 +1,751 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import shutil
+from collections import OrderedDict
+from copy import deepcopy
+from multiprocessing.pool import Pool
+from typing import Tuple, Union
+
+import SimpleITK as sitk
+import numpy as np
+import scipy.stats as ss
+from batchgenerators.utilities.file_and_folder_operations import *
+from medpy.metric import dc, hd95
+from nnunet.dataset_conversion.Task032_BraTS_2018 import convert_labels_back_to_BraTS_2018_2019_convention
+from nnunet.dataset_conversion.Task043_BraTS_2019 import copy_BraTS_segmentation_and_convert_labels
+from nnunet.evaluation.region_based_evaluation import get_brats_regions, evaluate_regions
+from nnunet.paths import nnUNet_raw_data
+from nnunet.postprocessing.consolidate_postprocessing import collect_cv_niftis
+
+
+def apply_brats_threshold(fname, out_dir, threshold, replace_with):
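+    # if fewer than `threshold` voxels are predicted as enhancing tumor (label 3), the
+    # enhancing tumor prediction is removed entirely by relabeling it to `replace_with`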
+ img_itk = sitk.ReadImage(fname)
+ img_npy = sitk.GetArrayFromImage(img_itk)
+ s = np.sum(img_npy == 3)
+ if s < threshold:
+ # print(s, fname)
+ img_npy[img_npy == 3] = replace_with
+ img_itk_postprocessed = sitk.GetImageFromArray(img_npy)
+ img_itk_postprocessed.CopyInformation(img_itk)
+ sitk.WriteImage(img_itk_postprocessed, join(out_dir, fname.split("/")[-1]))
+
+
+def load_niftis_threshold_compute_dice(gt_file, pred_file, thresholds: Union[list, tuple]):
+ gt = sitk.GetArrayFromImage(sitk.ReadImage(gt_file))
+ pred = sitk.GetArrayFromImage(sitk.ReadImage(pred_file))
+ mask_pred = pred == 3
+ mask_gt = gt == 3
+ num_pred = np.sum(mask_pred)
+
+ num_gt = np.sum(mask_gt)
+ dice = dc(mask_pred, mask_gt)
+
+ res_dice = {}
+ res_was_smaller = {}
+
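+    # for each candidate threshold t: if the predicted enhancing tumor is smaller than t,
+    # score the case as if the prediction were empty (Dice 1 if the reference is empty too,
+    # else 0); otherwise keep the raw Dice. This mimics how BraTS punishes small false
+    # positive enhancing tumor predictions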
+ for t in thresholds:
+ was_smaller = False
+
+ if num_pred < t:
+ was_smaller = True
+ if num_gt == 0:
+ dice_here = 1.
+ else:
+ dice_here = 0.
+ else:
+ dice_here = deepcopy(dice)
+
+ res_dice[t] = dice_here
+ res_was_smaller[t] = was_smaller
+
+ return res_was_smaller, res_dice
+
+
+def apply_threshold_to_folder(folder_in, folder_out, threshold, replace_with, processes=24):
+ maybe_mkdir_p(folder_out)
+ niftis = subfiles(folder_in, suffix='.nii.gz', join=True)
+
+ p = Pool(processes)
+ p.starmap(apply_brats_threshold, zip(niftis, [folder_out]*len(niftis), [threshold]*len(niftis), [replace_with] * len(niftis)))
+
+ p.close()
+ p.join()
+
+
+def determine_brats_postprocessing(folder_with_preds, folder_with_gt, postprocessed_output_dir, processes=8,
+ thresholds=(0, 10, 50, 100, 200, 500, 750, 1000, 1500, 2500, 10000), replace_with=2):
+ # find pairs
+ nifti_gt = subfiles(folder_with_gt, suffix=".nii.gz", sort=True)
+
+ p = Pool(processes)
+
+ nifti_pred = subfiles(folder_with_preds, suffix='.nii.gz', sort=True)
+
+ results = p.starmap_async(load_niftis_threshold_compute_dice, zip(nifti_gt, nifti_pred, [thresholds] * len(nifti_pred)))
+ results = results.get()
+
+ all_dc_per_threshold = {}
+ for t in thresholds:
+ all_dc_per_threshold[t] = np.array([i[1][t] for i in results])
+ print(t, np.mean(all_dc_per_threshold[t]))
+
+ means = [np.mean(all_dc_per_threshold[t]) for t in thresholds]
+ best_threshold = thresholds[np.argmax(means)]
+ print('best', best_threshold, means[np.argmax(means)])
+
+ maybe_mkdir_p(postprocessed_output_dir)
+
+ p.starmap(apply_brats_threshold, zip(nifti_pred, [postprocessed_output_dir]*len(nifti_pred), [best_threshold]*len(nifti_pred), [replace_with] * len(nifti_pred)))
+
+ p.close()
+ p.join()
+
+ save_pickle((thresholds, means, best_threshold, all_dc_per_threshold), join(postprocessed_output_dir, "threshold.pkl"))
+
+
+def collect_and_prepare(base_dir, num_processes=12, clean=False):
+    """
+    collect all cv_niftis, compute brats metrics, compute enh tumor thresholds and summarize in csv
+    :param base_dir:
+    :param num_processes:
+    :param clean: if True, recompute summaries and thresholds even if they already exist
+    :return:
+    """
+ out = join(base_dir, 'cv_results')
+ out_pp = join(base_dir, 'cv_results_pp')
+ experiments = subfolders(base_dir, join=False, prefix='nnUNetTrainer')
+ regions = get_brats_regions()
+ gt_dir = join(base_dir, 'gt_niftis')
+ replace_with = 2
+
+ failed = []
+ successful = []
+ for e in experiments:
+ print(e)
+ try:
+ o = join(out, e)
+ o_p = join(out_pp, e)
+ maybe_mkdir_p(o)
+ maybe_mkdir_p(o_p)
+ collect_cv_niftis(join(base_dir, e), o)
+ if clean or not isfile(join(o, 'summary.csv')):
+ evaluate_regions(o, gt_dir, regions, num_processes)
+ if clean or not isfile(join(o_p, 'threshold.pkl')):
+ determine_brats_postprocessing(o, gt_dir, o_p, num_processes, thresholds=list(np.arange(0, 760, 10)), replace_with=replace_with)
+ if clean or not isfile(join(o_p, 'summary.csv')):
+ evaluate_regions(o_p, gt_dir, regions, num_processes)
+ successful.append(e)
+ except Exception as ex:
+ print("\nERROR\n", e, ex, "\n")
+ failed.append(e)
+
+ # we are interested in the mean (nan is 1) column
+ with open(join(base_dir, 'cv_summary.csv'), 'w') as f:
+ f.write('name,whole,core,enh,mean\n')
+ for e in successful:
+ expected_nopp = join(out, e, 'summary.csv')
+            expected_pp = join(out_pp, e, 'summary.csv')
+ if isfile(expected_nopp):
+ res = np.loadtxt(expected_nopp, dtype=str, skiprows=0, delimiter=',')[-2]
+ as_numeric = [float(i) for i in res[1:]]
+ f.write(e + '_noPP,')
+ f.write("%0.4f," % as_numeric[0])
+ f.write("%0.4f," % as_numeric[1])
+ f.write("%0.4f," % as_numeric[2])
+ f.write("%0.4f\n" % np.mean(as_numeric))
+ if isfile(expected_pp):
+ res = np.loadtxt(expected_pp, dtype=str, skiprows=0, delimiter=',')[-2]
+ as_numeric = [float(i) for i in res[1:]]
+ f.write(e + '_PP,')
+ f.write("%0.4f," % as_numeric[0])
+ f.write("%0.4f," % as_numeric[1])
+ f.write("%0.4f," % as_numeric[2])
+ f.write("%0.4f\n" % np.mean(as_numeric))
+
+ # this just crawls the folders and evaluates what it finds
+ with open(join(base_dir, 'cv_summary2.csv'), 'w') as f:
+ for folder in ['cv_results', 'cv_results_pp']:
+ for ex in subdirs(join(base_dir, folder), join=False):
+ print(folder, ex)
+ expected = join(base_dir, folder, ex, 'summary.csv')
+ if clean or not isfile(expected):
+ evaluate_regions(join(base_dir, folder, ex), gt_dir, regions, num_processes)
+ if isfile(expected):
+ res = np.loadtxt(expected, dtype=str, skiprows=0, delimiter=',')[-2]
+ as_numeric = [float(i) for i in res[1:]]
+ f.write('%s__%s,' % (folder, ex))
+ f.write("%0.4f," % as_numeric[0])
+ f.write("%0.4f," % as_numeric[1])
+ f.write("%0.4f," % as_numeric[2])
+ f.write("%0.4f\n" % np.mean(as_numeric))
+
+ # apply threshold to val set
+ expected_num_cases = 125
+ missing_valset = []
+ has_val_pred = []
+ for e in successful:
+ if isdir(join(base_dir, 'predVal', e)):
+ currdir = join(base_dir, 'predVal', e)
+ files = subfiles(currdir, suffix='.nii.gz', join=False)
+ if len(files) != expected_num_cases:
+                print(e, 'prediction not done, found %d files, expected %d' % (len(files), expected_num_cases))
+ continue
+ output_folder = join(base_dir, 'predVal_PP', e)
+ maybe_mkdir_p(output_folder)
+ threshold = load_pickle(join(out_pp, e, 'threshold.pkl'))[2]
+ if threshold > 1000: threshold = 750 # don't make it too big!
+ apply_threshold_to_folder(currdir, output_folder, threshold, replace_with, num_processes)
+ has_val_pred.append(e)
+ else:
+ print(e, 'has no valset predictions')
+ missing_valset.append(e)
+
+ # 'nnUNetTrainerV2BraTSRegions_DA3_BN__nnUNetPlansv2.1_bs5_15fold' needs special treatment
+ e = 'nnUNetTrainerV2BraTSRegions_DA3_BN__nnUNetPlansv2.1_bs5'
+ currdir = join(base_dir, 'predVal', 'nnUNetTrainerV2BraTSRegions_DA3_BN__nnUNetPlansv2.1_bs5_15fold')
+ output_folder = join(base_dir, 'predVal_PP', 'nnUNetTrainerV2BraTSRegions_DA3_BN__nnUNetPlansv2.1_bs5_15fold')
+ maybe_mkdir_p(output_folder)
+ threshold = load_pickle(join(out_pp, e, 'threshold.pkl'))[2]
+ if threshold > 1000: threshold = 750 # don't make it too big!
+ apply_threshold_to_folder(currdir, output_folder, threshold, replace_with, num_processes)
+
+    # 'nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold' needs special treatment
+ e = 'nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5'
+ currdir = join(base_dir, 'predVal', 'nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold')
+ output_folder = join(base_dir, 'predVal_PP', 'nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold')
+ maybe_mkdir_p(output_folder)
+ threshold = load_pickle(join(out_pp, e, 'threshold.pkl'))[2]
+ if threshold > 1000: threshold = 750 # don't make it too big!
+ apply_threshold_to_folder(currdir, output_folder, threshold, replace_with, num_processes)
+
+ # convert val set to brats labels for submission
+ output_converted = join(base_dir, 'converted_valSet')
+
+ for source in ['predVal', 'predVal_PP']:
+ for e in has_val_pred + ['nnUNetTrainerV2BraTSRegions_DA3_BN__nnUNetPlansv2.1_bs5_15fold', 'nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold']:
+ expected_source_folder = join(base_dir, source, e)
+ if not isdir(expected_source_folder):
+ print(e, 'has no', source)
+ raise RuntimeError()
+ files = subfiles(expected_source_folder, suffix='.nii.gz', join=False)
+ if len(files) != expected_num_cases:
+                print(e, 'prediction not done, found %d files, expected %d' % (len(files), expected_num_cases))
+ continue
+ target_folder = join(output_converted, source, e)
+ maybe_mkdir_p(target_folder)
+ convert_labels_back_to_BraTS_2018_2019_convention(expected_source_folder, target_folder)
+
+ summarize_validation_set_predictions(output_converted)
+
+
+def summarize_validation_set_predictions(base):
+ with open(join(base, 'summary.csv'), 'w') as f:
+ f.write('name,whole,core,enh,mean,whole,core,enh,mean\n')
+ for subf in subfolders(base, join=False):
+ for e in subfolders(join(base, subf), join=False):
+ expected = join(base, subf, e, 'Stats_Validation_final.csv')
+ if not isfile(expected):
+ print(subf, e, 'has missing csv')
+ continue
+ a = np.loadtxt(expected, delimiter=',', dtype=str)
+                assert a.shape[0] == 131, 'expected 131 csv rows (125 cases + header + 5 aggregate rows), evaluation incomplete!'
+ selected_row = a[-5]
+ values = [float(i) for i in selected_row[1:4]]
+ f.write(e + "_" + subf + ',')
+ f.write("%0.4f," % values[1])
+ f.write("%0.4f," % values[2])
+ f.write("%0.4f," % values[0])
+ f.write("%0.4f," % np.mean(values))
+ values = [float(i) for i in selected_row[-3:]]
+ f.write("%0.4f," % values[1])
+ f.write("%0.4f," % values[2])
+ f.write("%0.4f," % values[0])
+ f.write("%0.4f\n" % np.mean(values))
+
+
+def compute_BraTS_dice(ref, pred):
+ """
+    ref and pred are binary integer numpy.ndarrays
+    :param ref:
+    :param pred:
+ :return:
+ """
+ num_ref = np.sum(ref)
+ num_pred = np.sum(pred)
+
+ if num_ref == 0:
+ if num_pred == 0:
+ return 1
+ else:
+ return 0
+ else:
+ return dc(pred, ref)
+
+
+def convert_all_to_BraTS(input_folder, output_folder, expected_num_cases=125):
+ for s in subdirs(input_folder, join=False):
+ nii = subfiles(join(input_folder, s), suffix='.nii.gz', join=False)
+ if len(nii) != expected_num_cases:
+ print(s)
+ else:
+ target_dir = join(output_folder, s)
+ convert_labels_back_to_BraTS_2018_2019_convention(join(input_folder, s), target_dir, num_processes=6)
+
+
+def compute_BraTS_HD95(ref, pred):
+ """
+    ref and pred are binary integer numpy.ndarrays
+ spacing is assumed to be (1, 1, 1)
+ :param ref:
+ :param pred:
+ :return:
+ """
+ num_ref = np.sum(ref)
+ num_pred = np.sum(pred)
+
+ if num_ref == 0:
+ if num_pred == 0:
+ return 0
+ else:
+ return 373.12866
+ elif num_pred == 0 and num_ref != 0:
+ return 373.12866
+ else:
+ return hd95(pred, ref, (1, 1, 1))
+
+
+def evaluate_BraTS_case(arr: np.ndarray, arr_gt: np.ndarray):
+ """
+ attempting to reimplement the brats evaluation scheme
+ assumes edema=1, non_enh=2, enh=3
+ :param arr:
+ :param arr_gt:
+ :return:
+ """
+ # whole tumor
+ mask_gt = (arr_gt != 0).astype(int)
+ mask_pred = (arr != 0).astype(int)
+ dc_whole = compute_BraTS_dice(mask_gt, mask_pred)
+ hd95_whole = compute_BraTS_HD95(mask_gt, mask_pred)
+ del mask_gt, mask_pred
+
+ # tumor core
+ mask_gt = (arr_gt > 1).astype(int)
+ mask_pred = (arr > 1).astype(int)
+ dc_core = compute_BraTS_dice(mask_gt, mask_pred)
+ hd95_core = compute_BraTS_HD95(mask_gt, mask_pred)
+ del mask_gt, mask_pred
+
+ # enhancing
+ mask_gt = (arr_gt == 3).astype(int)
+ mask_pred = (arr == 3).astype(int)
+ dc_enh = compute_BraTS_dice(mask_gt, mask_pred)
+ hd95_enh = compute_BraTS_HD95(mask_gt, mask_pred)
+ del mask_gt, mask_pred
+
+ return dc_whole, dc_core, dc_enh, hd95_whole, hd95_core, hd95_enh
+
+
+def load_evaluate(filename_gt: str, filename_pred: str):
+ arr_pred = sitk.GetArrayFromImage(sitk.ReadImage(filename_pred))
+ arr_gt = sitk.GetArrayFromImage(sitk.ReadImage(filename_gt))
+ return evaluate_BraTS_case(arr_pred, arr_gt)
+
+
+def evaluate_BraTS_folder(folder_pred, folder_gt, num_processes: int = 24, strict=False):
+ nii_pred = subfiles(folder_pred, suffix='.nii.gz', join=False)
+ if len(nii_pred) == 0:
+ return
+ nii_gt = subfiles(folder_gt, suffix='.nii.gz', join=False)
+ assert all([i in nii_gt for i in nii_pred]), 'not all predicted niftis have a reference file!'
+ if strict:
+ assert all([i in nii_pred for i in nii_gt]), 'not all gt niftis have a predicted file!'
+ p = Pool(num_processes)
+ nii_pred_fullpath = [join(folder_pred, i) for i in nii_pred]
+ nii_gt_fullpath = [join(folder_gt, i) for i in nii_pred]
+ results = p.starmap(load_evaluate, zip(nii_gt_fullpath, nii_pred_fullpath))
+ # now write to output file
+ with open(join(folder_pred, 'results.csv'), 'w') as f:
+ f.write("name,dc_whole,dc_core,dc_enh,hd95_whole,hd95_core,hd95_enh\n")
+ for fname, r in zip(nii_pred, results):
+ f.write(fname)
+ f.write(",%0.4f,%0.4f,%0.4f,%3.3f,%3.3f,%3.3f\n" % r)
+
+
+def load_csv_for_ranking(csv_file: str):
+ res = np.loadtxt(csv_file, dtype='str', delimiter=',')
+ scores = res[1:, [1, 2, 3, -3, -2, -1]].astype(float)
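+    # HD95 is "lower is better": negate and shift by the BraTS worst-case penalty so that
+    # all six columns become "higher is better" and lie within [0, 373.129]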
+ scores[:, -3:] *= -1
+ scores[:, -3:] += 373.129
+ assert np.all(scores <= 373.129)
+ assert np.all(scores >= 0)
+ return scores
+
+
+def rank_algorithms(data:np.ndarray):
+ """
+ data is (metrics x experiments x cases)
+ :param data:
+ :return:
+ """
+ num_metrics, num_experiments, num_cases = data.shape
+ ranks = np.zeros((num_metrics, num_experiments))
+    for m in range(num_metrics):
+ r = np.apply_along_axis(ss.rankdata, 0, -data[m], 'min')
+ ranks[m] = r.mean(1)
+ average_rank = np.mean(ranks, 0)
+ final_ranks = ss.rankdata(average_rank, 'min')
+ return final_ranks, average_rank, ranks
+
+
+def score_and_postprocess_model_based_on_rank_then_aggregate():
+ """
+ Similarly to BraTS 2017 - BraTS 2019, each participant will be ranked for each of the X test cases. Each case
+ includes 3 regions of evaluation, and the metrics used to produce the rankings will be the Dice Similarity
+ Coefficient and the 95% Hausdorff distance. Thus, for X number of cases included in the BraTS 2020, each
+ participant ends up having X*3*2 rankings. The final ranking score is the average of all these rankings normalized
+ by the number of teams.
+ https://zenodo.org/record/3718904
+
+ -> let's optimize for this.
+
+ Important: the outcome very much depends on the competing models. We need some references. We only got our own,
+ so let's hope this still works
+ :return:
+ """
+ base = "/media/fabian/Results/nnUNet/3d_fullres/Task082_BraTS2020"
+ replace_with = 2
+ num_processes = 24
+ expected_num_cases_val = 125
+
+ # use a separate output folder from the previous experiments to ensure we are not messing things up
+ output_base_here = join(base, 'use_brats_ranking')
+ maybe_mkdir_p(output_base_here)
+
+ # collect cv niftis and compute metrics with evaluate_BraTS_folder to ensure we work with the same metrics as brats
+ out = join(output_base_here, 'cv_results')
+ experiments = subfolders(base, join=False, prefix='nnUNetTrainer')
+ gt_dir = join(base, 'gt_niftis')
+
+ experiments_with_full_cv = []
+ for e in experiments:
+ print(e)
+ o = join(out, e)
+ maybe_mkdir_p(o)
+ try:
+ collect_cv_niftis(join(base, e), o)
+ if not isfile(join(o, 'results.csv')):
+ evaluate_BraTS_folder(o, gt_dir, num_processes, strict=True)
+ experiments_with_full_cv.append(e)
+ except Exception as ex:
+ print("\nERROR\n", e, ex, "\n")
+ if isfile(join(o, 'results.csv')):
+ os.remove(join(o, 'results.csv'))
+
+ # rank the non-postprocessed models
+ tmp = np.loadtxt(join(out, experiments_with_full_cv[0], 'results.csv'), dtype='str', delimiter=',')
+ num_cases = len(tmp) - 1
+ data_for_ranking = np.zeros((6, len(experiments_with_full_cv), num_cases))
+ for i, e in enumerate(experiments_with_full_cv):
+ scores = load_csv_for_ranking(join(out, e, 'results.csv'))
+ for metric in range(6):
+ data_for_ranking[metric, i] = scores[:, metric]
+
+ final_ranks, average_rank, ranks = rank_algorithms(data_for_ranking)
+
+ for t in np.argsort(final_ranks):
+ print(final_ranks[t], average_rank[t], experiments_with_full_cv[t])
+
+ # for each model, create output directories with different thresholds. evaluate ALL OF THEM (might take a while lol)
+ thresholds = np.arange(25, 751, 25)
+ output_pp_tmp = join(output_base_here, 'cv_determine_pp_thresholds')
+ for e in experiments_with_full_cv:
+ input_folder = join(out, e)
+ for t in thresholds:
+ output_directory = join(output_pp_tmp, e, str(t))
+ maybe_mkdir_p(output_directory)
+ if not isfile(join(output_directory, 'results.csv')):
+ apply_threshold_to_folder(input_folder, output_directory, t, replace_with, processes=16)
+ evaluate_BraTS_folder(output_directory, gt_dir, num_processes)
+
+ # load ALL the results!
+ results = []
+ experiment_names = []
+ for e in experiments_with_full_cv:
+ for t in thresholds:
+ output_directory = join(output_pp_tmp, e, str(t))
+ expected_file = join(output_directory, 'results.csv')
+ if not isfile(expected_file):
+ print(e, 'does not have a results file for threshold', t)
+ continue
+ results.append(load_csv_for_ranking(expected_file))
+ experiment_names.append("%s___%d" % (e, t))
+ all_results = np.concatenate([i[None] for i in results], 0).transpose((2, 0, 1))
+
+    # concatenate with the non-postprocessed models
+ all_results = np.concatenate((data_for_ranking, all_results), 1)
+ experiment_names += experiments_with_full_cv
+
+ final_ranks, average_rank, ranks = rank_algorithms(all_results)
+
+ for t in np.argsort(final_ranks):
+ print(final_ranks[t], average_rank[t], experiment_names[t])
+
+    # for each model, print the non-postprocessed model as well as the best postprocessed model. If there are
+    # validation set predictions, apply the best threshold to the validation set
+ pred_val_base = join(base, 'predVal_PP_rank')
+ has_val_pred = []
+ for e in experiments_with_full_cv:
+ rank_nonpp = final_ranks[experiment_names.index(e)]
+ avg_rank_nonpp = average_rank[experiment_names.index(e)]
+ print(e, avg_rank_nonpp, rank_nonpp)
+ predicted_val = join(base, 'predVal', e)
+
+ pp_models = [j for j, i in enumerate(experiment_names) if i.split("___")[0] == e and i != e]
+ if len(pp_models) > 0:
+ ranks = [final_ranks[i] for i in pp_models]
+ best_idx = np.argmin(ranks)
+ best = experiment_names[pp_models[best_idx]]
+ best_avg_rank = average_rank[pp_models[best_idx]]
+ print(best, best_avg_rank, min(ranks))
+ print('')
+ # apply threshold to validation set
+ best_threshold = int(best.split('___')[-1])
+ if not isdir(predicted_val):
+                print(e, 'has no valset predictions')
+ else:
+ files = subfiles(predicted_val, suffix='.nii.gz')
+ if len(files) != expected_num_cases_val:
+ print(e, 'has missing val cases. found: %d expected: %d' % (len(files), expected_num_cases_val))
+ else:
+ apply_threshold_to_folder(predicted_val, join(pred_val_base, e), best_threshold, replace_with, num_processes)
+ has_val_pred.append(e)
+ else:
+ print(e, 'not found in ranking')
+
+ # apply nnUNetTrainerV2BraTSRegions_DA3_BN__nnUNetPlansv2.1_bs5 to nnUNetTrainerV2BraTSRegions_DA3_BN__nnUNetPlansv2.1_bs5_15fold
+ e = 'nnUNetTrainerV2BraTSRegions_DA3_BN__nnUNetPlansv2.1_bs5'
+ pp_models = [j for j, i in enumerate(experiment_names) if i.split("___")[0] == e and i != e]
+ ranks = [final_ranks[i] for i in pp_models]
+ best_idx = np.argmin(ranks)
+ best = experiment_names[pp_models[best_idx]]
+ best_avg_rank = average_rank[pp_models[best_idx]]
+ best_threshold = int(best.split('___')[-1])
+ predicted_val = join(base, 'predVal', 'nnUNetTrainerV2BraTSRegions_DA3_BN__nnUNetPlansv2.1_bs5_15fold')
+ apply_threshold_to_folder(predicted_val, join(pred_val_base, 'nnUNetTrainerV2BraTSRegions_DA3_BN__nnUNetPlansv2.1_bs5_15fold'), best_threshold, replace_with, num_processes)
+ has_val_pred.append('nnUNetTrainerV2BraTSRegions_DA3_BN__nnUNetPlansv2.1_bs5_15fold')
+
+ # apply nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5 to nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold
+ e = 'nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5'
+ pp_models = [j for j, i in enumerate(experiment_names) if i.split("___")[0] == e and i != e]
+ ranks = [final_ranks[i] for i in pp_models]
+ best_idx = np.argmin(ranks)
+ best = experiment_names[pp_models[best_idx]]
+ best_avg_rank = average_rank[pp_models[best_idx]]
+ best_threshold = int(best.split('___')[-1])
+ predicted_val = join(base, 'predVal', 'nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold')
+ apply_threshold_to_folder(predicted_val, join(pred_val_base, 'nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold'), best_threshold, replace_with, num_processes)
+ has_val_pred.append('nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold')
+
+ # convert valsets
+ output_converted = join(base, 'converted_valSet')
+ for e in has_val_pred:
+ expected_source_folder = join(base, 'predVal_PP_rank', e)
+ if not isdir(expected_source_folder):
+ print(e, 'has no predVal_PP_rank')
+ raise RuntimeError()
+ files = subfiles(expected_source_folder, suffix='.nii.gz', join=False)
+ if len(files) != expected_num_cases_val:
+            print(e, 'prediction not done, found %d files, expected %d' % (len(files), expected_num_cases_val))
+ continue
+ target_folder = join(output_converted, 'predVal_PP_rank', e)
+ maybe_mkdir_p(target_folder)
+ convert_labels_back_to_BraTS_2018_2019_convention(expected_source_folder, target_folder)
+
+ # now load all the csvs for the validation set (obtained from evaluation platform) and rank our models on the
+ # validation set
+ flds = subdirs(output_converted, join=False)
+ results_valset = []
+ names_valset = []
+ for f in flds:
+ curr = join(output_converted, f)
+ experiments = subdirs(curr, join=False)
+ for e in experiments:
+ currr = join(curr, e)
+ expected_file = join(currr, 'Stats_Validation_final.csv')
+ if not isfile(expected_file):
+ print(f, e, "has not been evaluated yet!")
+ else:
+ res = load_csv_for_ranking(expected_file)[:-5]
+ assert res.shape[0] == expected_num_cases_val
+ results_valset.append(res[None])
+ names_valset.append("%s___%s" % (f, e))
+ results_valset = np.concatenate(results_valset, 0) # experiments x cases x metrics
+ # convert to metrics x experiments x cases
+ results_valset = results_valset.transpose((2, 0, 1))
+ final_ranks, average_rank, ranks = rank_algorithms(results_valset)
+ for t in np.argsort(final_ranks):
+ print(final_ranks[t], average_rank[t], names_valset[t])
+
+
+if __name__ == "__main__":
+ """
+ THIS CODE IS A MESS. IT IS PROVIDED AS IS WITH NO GUARANTEES. YOU HAVE TO DIG THROUGH IT YOURSELF. GOOD LUCK ;-)
+
+ REMEMBER TO CONVERT LABELS BACK TO BRATS CONVENTION AFTER PREDICTION!
+ """
+
+ task_name = "Task082_BraTS2020"
+ downloaded_data_dir = "/home/fabian/Downloads/MICCAI_BraTS2020_TrainingData"
+ downloaded_data_dir_val = "/home/fabian/Downloads/MICCAI_BraTS2020_ValidationData"
+
+ target_base = join(nnUNet_raw_data, task_name)
+ target_imagesTr = join(target_base, "imagesTr")
+ target_imagesVal = join(target_base, "imagesVal")
+ target_imagesTs = join(target_base, "imagesTs")
+ target_labelsTr = join(target_base, "labelsTr")
+
+ maybe_mkdir_p(target_imagesTr)
+ maybe_mkdir_p(target_imagesVal)
+ maybe_mkdir_p(target_imagesTs)
+ maybe_mkdir_p(target_labelsTr)
+
+ patient_names = []
+ cur = join(downloaded_data_dir)
+ for p in subdirs(cur, join=False):
+ patdir = join(cur, p)
+ patient_name = p
+ patient_names.append(patient_name)
+ t1 = join(patdir, p + "_t1.nii.gz")
+ t1c = join(patdir, p + "_t1ce.nii.gz")
+ t2 = join(patdir, p + "_t2.nii.gz")
+ flair = join(patdir, p + "_flair.nii.gz")
+ seg = join(patdir, p + "_seg.nii.gz")
+
+ assert all([
+ isfile(t1),
+ isfile(t1c),
+ isfile(t2),
+ isfile(flair),
+ isfile(seg)
+ ]), "%s" % patient_name
+
+ shutil.copy(t1, join(target_imagesTr, patient_name + "_0000.nii.gz"))
+ shutil.copy(t1c, join(target_imagesTr, patient_name + "_0001.nii.gz"))
+ shutil.copy(t2, join(target_imagesTr, patient_name + "_0002.nii.gz"))
+ shutil.copy(flair, join(target_imagesTr, patient_name + "_0003.nii.gz"))
+
+ copy_BraTS_segmentation_and_convert_labels(seg, join(target_labelsTr, patient_name + ".nii.gz"))
+
+
+ json_dict = OrderedDict()
+ json_dict['name'] = "BraTS2020"
+ json_dict['description'] = "nothing"
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "see BraTS2020"
+ json_dict['licence'] = "see BraTS2020 license"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "T1",
+ "1": "T1ce",
+ "2": "T2",
+ "3": "FLAIR"
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "edema",
+ "2": "non-enhancing",
+ "3": "enhancing",
+ }
+ json_dict['numTraining'] = len(patient_names)
+ json_dict['numTest'] = 0
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i, "label": "./labelsTr/%s.nii.gz" % i} for i in
+ patient_names]
+ json_dict['test'] = []
+
+ save_json(json_dict, join(target_base, "dataset.json"))
+
+ if downloaded_data_dir_val is not None:
+ for p in subdirs(downloaded_data_dir_val, join=False):
+ patdir = join(downloaded_data_dir_val, p)
+ patient_name = p
+ t1 = join(patdir, p + "_t1.nii.gz")
+ t1c = join(patdir, p + "_t1ce.nii.gz")
+ t2 = join(patdir, p + "_t2.nii.gz")
+ flair = join(patdir, p + "_flair.nii.gz")
+
+ assert all([
+ isfile(t1),
+ isfile(t1c),
+ isfile(t2),
+ isfile(flair),
+ ]), "%s" % patient_name
+
+ shutil.copy(t1, join(target_imagesVal, patient_name + "_0000.nii.gz"))
+ shutil.copy(t1c, join(target_imagesVal, patient_name + "_0001.nii.gz"))
+ shutil.copy(t2, join(target_imagesVal, patient_name + "_0002.nii.gz"))
+ shutil.copy(flair, join(target_imagesVal, patient_name + "_0003.nii.gz"))
+
+
+ downloaded_data_dir_test = "/home/fabian/Downloads/MICCAI_BraTS2020_TestingData"
+
+ if isdir(downloaded_data_dir_test):
+ for p in subdirs(downloaded_data_dir_test, join=False):
+ patdir = join(downloaded_data_dir_test, p)
+ patient_name = p
+ t1 = join(patdir, p + "_t1.nii.gz")
+ t1c = join(patdir, p + "_t1ce.nii.gz")
+ t2 = join(patdir, p + "_t2.nii.gz")
+ flair = join(patdir, p + "_flair.nii.gz")
+
+ assert all([
+ isfile(t1),
+ isfile(t1c),
+ isfile(t2),
+ isfile(flair),
+ ]), "%s" % patient_name
+
+ shutil.copy(t1, join(target_imagesTs, patient_name + "_0000.nii.gz"))
+ shutil.copy(t1c, join(target_imagesTs, patient_name + "_0001.nii.gz"))
+ shutil.copy(t2, join(target_imagesTs, patient_name + "_0002.nii.gz"))
+ shutil.copy(flair, join(target_imagesTs, patient_name + "_0003.nii.gz"))
+
+ # test set
+ # nnUNet_ensemble -f nnUNetTrainerV2BraTSRegions_DA3_BN_BD__nnUNetPlansv2.1_bs5_5fold nnUNetTrainerV2BraTSRegions_DA4_BN_BD__nnUNetPlansv2.1_bs5_5fold nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold -o ensembled_nnUNetTrainerV2BraTSRegions_DA3_BN_BD__nnUNetPlansv2.1_bs5_5fold__nnUNetTrainerV2BraTSRegions_DA4_BN_BD__nnUNetPlansv2.1_bs5_5fold__nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold
+ # apply_threshold_to_folder('ensembled_nnUNetTrainerV2BraTSRegions_DA3_BN_BD__nnUNetPlansv2.1_bs5_5fold__nnUNetTrainerV2BraTSRegions_DA4_BN_BD__nnUNetPlansv2.1_bs5_5fold__nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold/', 'ensemble_PP200/', 200, 2)
+ # convert_labels_back_to_BraTS_2018_2019_convention('ensemble_PP200/', 'ensemble_PP200_converted')
+
+ # export for publication of weights
+ # nnUNet_export_model_to_zip -tr nnUNetTrainerV2BraTSRegions_DA4_BN -pl nnUNetPlansv2.1_bs5 -f 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 -t 82 -o nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold.zip --disable_strict
+ # nnUNet_export_model_to_zip -tr nnUNetTrainerV2BraTSRegions_DA3_BN_BD -pl nnUNetPlansv2.1_bs5 -f 0 1 2 3 4 -t 82 -o nnUNetTrainerV2BraTSRegions_DA3_BN_BD__nnUNetPlansv2.1_bs5_5fold.zip --disable_strict
+ # nnUNet_export_model_to_zip -tr nnUNetTrainerV2BraTSRegions_DA4_BN_BD -pl nnUNetPlansv2.1_bs5 -f 0 1 2 3 4 -t 82 -o nnUNetTrainerV2BraTSRegions_DA4_BN_BD__nnUNetPlansv2.1_bs5_5fold.zip --disable_strict
diff --git a/nnunet/dataset_conversion/Task083_VerSe2020.py b/nnunet/dataset_conversion/Task083_VerSe2020.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d0c9806891b6da1abda17ca9797580b98d505d2
--- /dev/null
+++ b/nnunet/dataset_conversion/Task083_VerSe2020.py
@@ -0,0 +1,138 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import shutil
+from collections import OrderedDict
+from copy import deepcopy
+from multiprocessing.pool import Pool
+
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.dataset_conversion.Task056_VerSe2019 import check_if_all_in_good_orientation, \
+ print_unique_labels_and_their_volumes
+from nnunet.paths import nnUNet_raw_data, preprocessing_output_dir
+from nnunet.utilities.image_reorientation import reorient_all_images_in_folder_to_ras
+
+
+def manually_change_plans():
+ pp_out_folder = join(preprocessing_output_dir, "Task083_VerSe2020")
+ original_plans = join(pp_out_folder, "nnUNetPlansv2.1_plans_3D.pkl")
+ assert isfile(original_plans)
+ original_plans = load_pickle(original_plans)
+
+ # let's change the network topology for lowres and fullres
+ new_plans = deepcopy(original_plans)
+ stages = len(new_plans['plans_per_stage'])
+ for s in range(stages):
+ new_plans['plans_per_stage'][s]['patch_size'] = (224, 160, 160)
+ new_plans['plans_per_stage'][s]['pool_op_kernel_sizes'] = [[2, 2, 2],
+ [2, 2, 2],
+ [2, 2, 2],
+ [2, 2, 2],
+                                                                   [2, 2, 2]]  # 5 stride-2 poolings: 224/2**5 = 7 and 160/2**5 = 5 -> bottleneck of 7x5x5
+ new_plans['plans_per_stage'][s]['conv_kernel_sizes'] = [[3, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3]]
+ save_pickle(new_plans, join(pp_out_folder, "custom_plans_3D.pkl"))
+
+
+if __name__ == "__main__":
+    ### First we create an nnU-Net dataset from VerSe. After this the images will still be all willy-nilly in their
+    # orientation because that's how VerSe comes
+ base = '/home/fabian/Downloads/osfstorage-archive/'
+
+ task_id = 83
+ task_name = "VerSe2020"
+
+ foldername = "Task%03.0d_%s" % (task_id, task_name)
+
+ out_base = join(nnUNet_raw_data, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ train_patient_names = []
+
+ for t in subdirs(join(base, 'training_data'), join=False):
+ train_patient_names_here = [i[:-len("_seg.nii.gz")] for i in
+ subfiles(join(base, "training_data", t), join=False, suffix="_seg.nii.gz")]
+ for p in train_patient_names_here:
+ curr = join(base, "training_data", t)
+ label_file = join(curr, p + "_seg.nii.gz")
+ image_file = join(curr, p + ".nii.gz")
+ shutil.copy(image_file, join(imagestr, p + "_0000.nii.gz"))
+ shutil.copy(label_file, join(labelstr, p + ".nii.gz"))
+
+ train_patient_names += train_patient_names_here
+
+ json_dict = OrderedDict()
+ json_dict['name'] = "VerSe2020"
+ json_dict['description'] = "VerSe2020"
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "see challenge website"
+ json_dict['licence'] = "see challenge website"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "CT",
+ }
+ json_dict['labels'] = {i: str(i) for i in range(29)}
+
+ json_dict['numTraining'] = len(train_patient_names)
+    json_dict['numTest'] = 0
+ json_dict['training'] = [
+ {'image': "./imagesTr/%s.nii.gz" % i.split("/")[-1], "label": "./labelsTr/%s.nii.gz" % i.split("/")[-1]} for i
+ in
+ train_patient_names]
+    json_dict['test'] = []
+
+ save_json(json_dict, os.path.join(out_base, "dataset.json"))
+
+ # now we reorient all those images to ras. This saves a pkl with the original affine. We need this information to
+ # bring our predictions into the same geometry for submission
+ reorient_all_images_in_folder_to_ras(imagestr, 16)
+ reorient_all_images_in_folder_to_ras(imagests, 16)
+ reorient_all_images_in_folder_to_ras(labelstr, 16)
+
+ # sanity check
+ check_if_all_in_good_orientation(imagestr, labelstr, join(out_base, 'sanitycheck'))
+ # looks good to me - proceed
+
+ # check the volumes of the vertebrae
+    p = Pool(6)
+    label_files = subfiles(labelstr, suffix='.nii.gz')
+    _ = p.starmap(print_unique_labels_and_their_volumes, zip(label_files, [1000] * len(label_files)))
+
+ # looks good
+
+ # Now we are ready to run nnU-Net
+
+ """# run this part of the code once training is done
+ folder_gt = "/media/fabian/My Book/MedicalDecathlon/nnUNet_raw_splitted/Task056_VerSe/labelsTr"
+
+ folder_pred = "/home/fabian/drives/datasets/results/nnUNet/3d_fullres/Task056_VerSe/nnUNetTrainerV2__nnUNetPlansv2.1/cv_niftis_raw"
+ out_json = "/home/fabian/Task056_VerSe_3d_fullres_summary.json"
+ evaluate_verse_folder(folder_pred, folder_gt, out_json)
+
+ folder_pred = "/home/fabian/drives/datasets/results/nnUNet/3d_lowres/Task056_VerSe/nnUNetTrainerV2__nnUNetPlansv2.1/cv_niftis_raw"
+ out_json = "/home/fabian/Task056_VerSe_3d_lowres_summary.json"
+ evaluate_verse_folder(folder_pred, folder_gt, out_json)
+
+ folder_pred = "/home/fabian/drives/datasets/results/nnUNet/3d_cascade_fullres/Task056_VerSe/nnUNetTrainerV2CascadeFullRes__nnUNetPlansv2.1/cv_niftis_raw"
+ out_json = "/home/fabian/Task056_VerSe_3d_cascade_fullres_summary.json"
+ evaluate_verse_folder(folder_pred, folder_gt, out_json)"""
diff --git a/nnunet/dataset_conversion/Task089_Fluo-N2DH-SIM.py b/nnunet/dataset_conversion/Task089_Fluo-N2DH-SIM.py
new file mode 100644
index 0000000000000000000000000000000000000000..4505be90d88dc29c21501ace680d1f122681f46c
--- /dev/null
+++ b/nnunet/dataset_conversion/Task089_Fluo-N2DH-SIM.py
@@ -0,0 +1,290 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+from multiprocessing import Pool
+
+import SimpleITK as sitk
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from skimage.io import imread
+from skimage.io import imsave
+from skimage.morphology import disk
+from skimage.morphology import erosion
+from skimage.transform import resize
+
+from nnunet.paths import nnUNet_raw_data
+
+
+def load_bmp_convert_to_nifti_borders_2d(img_file, lab_file, img_out_base, anno_out, spacing, border_thickness=0.7):
+ img = imread(img_file)
+ img_itk = sitk.GetImageFromArray(img.astype(np.float32)[None])
+ img_itk.SetSpacing(list(spacing)[::-1] + [999])
+ sitk.WriteImage(img_itk, join(img_out_base + "_0000.nii.gz"))
+
+ if lab_file is not None:
+ l = imread(lab_file)
+ borders = generate_border_as_suggested_by_twollmann_2d(l, spacing, border_thickness)
+ l[l > 0] = 1
+ l[borders == 1] = 2
+ l_itk = sitk.GetImageFromArray(l.astype(np.uint8)[None])
+ l_itk.SetSpacing(list(spacing)[::-1] + [999])
+ sitk.WriteImage(l_itk, anno_out)
+
+
+def generate_disk(spacing, radius, dtype=int):
+ radius_in_voxels = np.round(radius / np.array(spacing)).astype(int)
+ n = 2 * radius_in_voxels + 1
+ disk_iso = disk(max(n) * 2, dtype=np.float64)
+ disk_resampled = resize(disk_iso, n, 1, 'constant', 0, clip=True, anti_aliasing=False, preserve_range=True)
+ disk_resampled[disk_resampled > 0.5] = 1
+ disk_resampled[disk_resampled <= 0.5] = 0
+ return disk_resampled.astype(dtype)
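+
+# Worked example (values taken from the __main__ block below): with spacing (0.125, 0.125) and
+# radius 0.7, radius_in_voxels = round(0.7 / 0.125) = 6 per axis, so the structuring element is a
+# 13x13 (= 2*6+1) disk resampled onto the voxel grid.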
+
+
+def generate_border_as_suggested_by_twollmann_2d(label_img: np.ndarray, spacing,
+ border_thickness: float = 2) -> np.ndarray:
+ border = np.zeros_like(label_img)
+ selem = generate_disk(spacing, border_thickness)
+ for l in np.unique(label_img):
+ if l == 0: continue
+ mask = (label_img == l).astype(int)
+ eroded = erosion(mask, selem)
+ border[(eroded == 0) & (mask != 0)] = 1
+ return border
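+
+
+# Hedged sketch (illustration only, not used by the conversion below): border extraction on a
+# made-up two-cell label image. Pixels of a cell that are eroded away become border (1), the rest
+# stays core.
+def _example_border_generation():
+    toy = np.zeros((32, 32), dtype=int)
+    toy[4:14, 4:14] = 1    # hypothetical cell 1
+    toy[18:30, 18:30] = 2  # hypothetical cell 2
+    border = generate_border_as_suggested_by_twollmann_2d(toy, spacing=(1., 1.), border_thickness=1.)
+    print(border.sum(), 'border pixels')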
+
+
+def prepare_task(base, task_id, task_name, spacing, border_thickness: float = 15):
+ p = Pool(16)
+
+ foldername = "Task%03.0d_%s" % (task_id, task_name)
+
+ out_base = join(nnUNet_raw_data, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ train_patient_names = []
+ test_patient_names = []
+ res = []
+
+ for train_sequence in [i for i in subfolders(base + "_train", join=False) if not i.endswith("_GT")]:
+ train_cases = subfiles(join(base + '_train', train_sequence), suffix=".tif", join=False)
+ for t in train_cases:
+ casename = train_sequence + "_" + t[:-4]
+ img_file = join(base + '_train', train_sequence, t)
+ lab_file = join(base + '_train', train_sequence + "_GT", "SEG", "man_seg" + t[1:])
+ if not isfile(lab_file):
+ continue
+ img_out_base = join(imagestr, casename)
+ anno_out = join(labelstr, casename + ".nii.gz")
+ res.append(
+ p.starmap_async(load_bmp_convert_to_nifti_borders_2d,
+ ((img_file, lab_file, img_out_base, anno_out, spacing, border_thickness),)))
+ train_patient_names.append(casename)
+
+ for test_sequence in [i for i in subfolders(base + "_test", join=False) if not i.endswith("_GT")]:
+ test_cases = subfiles(join(base + '_test', test_sequence), suffix=".tif", join=False)
+ for t in test_cases:
+ casename = test_sequence + "_" + t[:-4]
+ img_file = join(base + '_test', test_sequence, t)
+ lab_file = None
+ img_out_base = join(imagests, casename)
+ anno_out = None
+ res.append(
+ p.starmap_async(load_bmp_convert_to_nifti_borders_2d,
+ ((img_file, lab_file, img_out_base, anno_out, spacing, border_thickness),)))
+ test_patient_names.append(casename)
+
+ _ = [i.get() for i in res]
+
+ json_dict = {}
+ json_dict['name'] = task_name
+ json_dict['description'] = ""
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = ""
+ json_dict['licence'] = ""
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "BF",
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "cell",
+ "2": "border",
+ }
+
+ json_dict['numTraining'] = len(train_patient_names)
+ json_dict['numTest'] = len(test_patient_names)
+ json_dict['training'] = [{'image': "./imagesTr/%s.nii.gz" % i, "label": "./labelsTr/%s.nii.gz" % i} for i in
+ train_patient_names]
+ json_dict['test'] = ["./imagesTs/%s.nii.gz" % i for i in test_patient_names]
+
+ save_json(json_dict, os.path.join(out_base, "dataset.json"))
+ p.close()
+ p.join()
+
+
+def convert_to_instance_seg(arr: np.ndarray, spacing: tuple = (0.125, 0.125), small_center_threshold: int = 30,
+ isolated_border_as_separate_instance_threshold=15):
+ from skimage.morphology import label, dilation
+
+ # we first identify centers that are too small and set them to be border. This should remove false positive instances
+ objects = label((arr == 1).astype(int))
+ for o in np.unique(objects):
+ if o > 0 and np.sum(objects == o) <= small_center_threshold:
+ arr[objects == o] = 2
+
+ # 1 is core, 2 is border
+ objects = label((arr == 1).astype(int))
+ final = np.copy(objects)
+ remaining_border = arr == 2
+ current = np.copy(objects)
+ dilated_mm = np.array((0, 0))
+ spacing = np.array(spacing)
+
+ while np.sum(remaining_border) > 0:
+ strel_size = [0, 0]
+ maximum_dilation = max(dilated_mm)
+ for i in range(2):
+ if spacing[i] == min(spacing):
+ strel_size[i] = 1
+ continue
+ if dilated_mm[i] + spacing[i] / 2 < maximum_dilation:
+ strel_size[i] = 1
+ ball_here = disk(1)
+
+ if strel_size[0] == 0: ball_here = ball_here[1:2]
+ if strel_size[1] == 0: ball_here = ball_here[:, 1:2]
+
+ dilated = dilation(current, ball_here)
+ diff = (current == 0) & (dilated != current)
+ final[diff & remaining_border] = dilated[diff & remaining_border]
+ remaining_border[diff] = 0
+ current = dilated
+ dilated_mm = [dilated_mm[i] + spacing[i] if strel_size[i] == 1 else dilated_mm[i] for i in range(2)]
+
+ # what can happen is that a cell is so small that the network only predicted border and no core. This cell will be
+ # fused with the nearest other instance, which we don't want. Therefore we identify isolated border predictions and
+ # give them a separate instance id
+ # we identify isolated border predictions by checking each foreground object in arr and see whether this object
+ # also contains label 1
+ max_label = np.max(final)
+
+ foreground_objects = label((arr != 0).astype(int))
+ for i in np.unique(foreground_objects):
+ if i > 0 and (1 not in np.unique(arr[foreground_objects==i])):
+ size_of_object = np.sum(foreground_objects==i)
+ if size_of_object >= isolated_border_as_separate_instance_threshold:
+ final[foreground_objects == i] = max_label + 1
+ max_label += 1
+
+ return final.astype(np.uint32)
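+
+
+# Hedged sketch (toy input, not part of the conversion pipeline): two cores separated by a small
+# border strip become two separate instance ids; the strip is split between its nearest cores.
+def _example_convert_to_instance_seg():
+    toy = np.zeros((20, 20), dtype=int)
+    toy[2:9, 2:9] = 1      # hypothetical core of cell 1
+    toy[11:18, 11:18] = 1  # hypothetical core of cell 2
+    toy[9:11, 9:11] = 2    # border strip between the two cores
+    instances = convert_to_instance_seg(toy)
+    print(np.unique(instances))  # expect [0 1 2]: background plus two instances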
+
+
+def load_convert_to_instance_save(file_in: str, file_out: str, spacing):
+ img = sitk.ReadImage(file_in)
+ img_npy = sitk.GetArrayFromImage(img)
+ out = convert_to_instance_seg(img_npy[0], spacing)[None]
+ out_itk = sitk.GetImageFromArray(out.astype(np.int16))
+ out_itk.CopyInformation(img)
+ sitk.WriteImage(out_itk, file_out)
+
+
+def convert_folder_to_instanceseg(folder_in: str, folder_out: str, spacing, processes: int = 12):
+ input_files = subfiles(folder_in, suffix=".nii.gz", join=False)
+ maybe_mkdir_p(folder_out)
+ output_files = [join(folder_out, i) for i in input_files]
+ input_files = [join(folder_in, i) for i in input_files]
+ p = Pool(processes)
+ r = []
+ for i, o in zip(input_files, output_files):
+ r.append(
+ p.starmap_async(
+ load_convert_to_instance_save,
+ ((i, o, spacing),)
+ )
+ )
+ _ = [i.get() for i in r]
+ p.close()
+ p.join()
+
+
+def convert_to_tiff(nifti_image: str, output_name: str):
+ npy = sitk.GetArrayFromImage(sitk.ReadImage(nifti_image))
+ imsave(output_name, npy[0].astype(np.uint16), compress=6)
+
+
+if __name__ == "__main__":
+ base = "/home/fabian/Downloads/Fluo-N2DH-SIM+"
+ task_name = 'Fluo-N2DH-SIM'
+ spacing = (0.125, 0.125)
+
+ task_id = 999
+ border_thickness = 0.7
+ prepare_task(base, task_id, task_name, spacing, border_thickness)
+
+ task_id = 89
+ additional_time_steps = 4
+ task_name = 'Fluo-N2DH-SIM_thickborder_time'
+ full_taskname = 'Task%03.0d_' % task_id + task_name
+ output_raw = join(nnUNet_raw_data, full_taskname)
+    if isdir(output_raw):
+        shutil.rmtree(output_raw)
+ shutil.copytree(join(nnUNet_raw_data, 'Task999_Fluo-N2DH-SIM_thickborder'), output_raw)
+
+ shutil.rmtree(join(nnUNet_raw_data, 'Task999_Fluo-N2DH-SIM_thickborder'))
+
+ # now add additional time information
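+    # each frame becomes a 5-channel case: channels 0-3 hold the four preceding time steps (zero-filled
+    # where no predecessor exists, e.g. at the start of a sequence) and channel 4 holds the frame that is
+    # to be segmented. This matches the 'modality' entries written to dataset.json below.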
+ for fld in ['imagesTr', 'imagesTs']:
+ curr = join(output_raw, fld)
+ for seq in ['01', '02']:
+ images = subfiles(curr, prefix=seq, join=False)
+ for i in images:
+ current_timestep = int(i.split('_')[1][1:])
+ renamed = join(curr, i.replace("_0000", "_%04.0d" % additional_time_steps))
+ shutil.move(join(curr, i), renamed)
+ for previous_timestep in range(-additional_time_steps, 0):
+ # previous time steps will already have been processed and renamed!
+ expected_filename = join(curr, seq + "_t%03.0d" % (
+ current_timestep + previous_timestep) + "_%04.0d" % additional_time_steps + ".nii.gz")
+ if not isfile(expected_filename):
+ # create empty image
+ img = sitk.ReadImage(renamed)
+ empty = sitk.GetImageFromArray(np.zeros_like(sitk.GetArrayFromImage(img)))
+ empty.CopyInformation(img)
+ sitk.WriteImage(empty, join(curr, i.replace("_0000", "_%04.0d" % (
+ additional_time_steps + previous_timestep))))
+ else:
+ shutil.copy(expected_filename, join(curr, i.replace("_0000", "_%04.0d" % (
+ additional_time_steps + previous_timestep))))
+ dataset = load_json(join(output_raw, 'dataset.json'))
+ dataset['modality'] = {
+ '0': 't_minus 4',
+ '1': 't_minus 3',
+ '2': 't_minus 2',
+ '3': 't_minus 1',
+ '4': 'frame of interest',
+ }
+ save_json(dataset, join(output_raw, 'dataset.json'))
+
+ # we do not need custom splits since we train on all training cases
+
+ # test set predictions are converted to instance seg with convert_folder_to_instanceseg
+
+ # test set predictions are converted to tiff with convert_to_tiff
\ No newline at end of file
diff --git a/nnunet/dataset_conversion/Task114_heart_MNMs.py b/nnunet/dataset_conversion/Task114_heart_MNMs.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bc0a4c3988b53ea3179fa761962fcb8647765ca
--- /dev/null
+++ b/nnunet/dataset_conversion/Task114_heart_MNMs.py
@@ -0,0 +1,262 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+from collections import OrderedDict
+
+import numpy as np
+import pandas as pd
+from batchgenerators.utilities.file_and_folder_operations import *
+from numpy.random.mtrand import RandomState
+
+from nnunet.experiment_planning.common_utils import split_4d_nifti
+
+
+def get_mnms_data(data_root):
+ files_raw = []
+ files_gt = []
+ for r, dirs, files in os.walk(data_root):
+ for f in files:
+ if f.endswith('nii.gz'):
+ file_path = os.path.join(r, f)
+ if '_gt' in f:
+ files_gt.append(file_path)
+ else:
+ files_raw.append(file_path)
+ return files_raw, files_gt
+
+
+def generate_filename_for_nnunet(pat_id, ts, pat_folder=None, add_zeros=False, vendor=None, centre=None, mode='mnms',
+ data_format='nii.gz'):
+ if not vendor or not centre:
+ if add_zeros:
+ filename = "{}_{}_0000.{}".format(pat_id, str(ts).zfill(4), data_format)
+ else:
+ filename = "{}_{}.{}".format(pat_id, str(ts).zfill(4), data_format)
+ else:
+ if mode == 'mnms':
+ if add_zeros:
+ filename = "{}_{}_{}_{}_0000.{}".format(pat_id, str(ts).zfill(4), vendor, centre, data_format)
+ else:
+ filename = "{}_{}_{}_{}.{}".format(pat_id, str(ts).zfill(4), vendor, centre, data_format)
+ else:
+ if add_zeros:
+ filename = "{}_{}_{}_{}_0000.{}".format(vendor, centre, pat_id, str(ts).zfill(4), data_format)
+ else:
+ filename = "{}_{}_{}_{}.{}".format(vendor, centre, pat_id, str(ts).zfill(4), data_format)
+
+ if pat_folder:
+ filename = os.path.join(pat_folder, filename)
+ return filename
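+
+
+# Illustration (hypothetical patient id): in the default 'mnms' mode,
+# generate_filename_for_nnunet('A0S9V9', 0, vendor='A', centre=1, add_zeros=True)
+# yields 'A0S9V9_0000_A_1_0000.nii.gz'.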
+
+
+def select_annotated_frames_mms(data_folder, out_folder, add_zeros=False, is_gt=False,
+ df_path="/media/full/tera2/data/challenges/mms/Training-corrected_original/M&Ms Dataset Information.xlsx",
+ mode='mnms',):
+ table = pd.read_excel(df_path, index_col='External code')
+
+ for idx in table.index:
+ ed = table.loc[idx, 'ED']
+ es = table.loc[idx, 'ES']
+ vendor = table.loc[idx, 'Vendor']
+ centre = table.loc[idx, 'Centre']
+
+ if vendor != "C": # vendor C is for test data
+
+            # this step is needed for M&Ms data to adjust it to the nnU-Net framework
+ # generate old filename (w/o vendor and centre)
+ if is_gt:
+ add_to_name = 'sa_gt'
+ else:
+ add_to_name = 'sa'
+ filename_ed_original = os.path.join(
+ data_folder, "{}_{}_{}.nii.gz".format(idx, add_to_name, str(ed).zfill(4)))
+ filename_es_original = os.path.join(
+ data_folder, "{}_{}_{}.nii.gz".format(idx, add_to_name, str(es).zfill(4)))
+
+ # generate new filename with vendor and centre
+ filename_ed = generate_filename_for_nnunet(pat_id=idx, ts=ed, pat_folder=out_folder,
+ vendor=vendor, centre=centre, add_zeros=add_zeros, mode=mode)
+ filename_es = generate_filename_for_nnunet(pat_id=idx, ts=es, pat_folder=out_folder,
+ vendor=vendor, centre=centre, add_zeros=add_zeros, mode=mode)
+
+ shutil.copy(filename_ed_original, filename_ed)
+ shutil.copy(filename_es_original, filename_es)
+
+
+def create_custom_splits_for_experiments(task_path):
+ data_keys = [i[:-4] for i in
+ subfiles(os.path.join(task_path, "nnUNetData_plans_v2.1_2D_stage0"),
+ join=False, suffix='npz')]
+ existing_splits = os.path.join(task_path, "splits_final.pkl")
+
+ splits = load_pickle(existing_splits)
+ splits = splits[:5] # discard old changes
+
+ unique_a_only = np.unique([i.split('_')[0] for i in data_keys if i.find('_A_') != -1])
+ unique_b_only = np.unique([i.split('_')[0] for i in data_keys if i.find('_B_') != -1])
+
+ num_train_a = int(np.round(0.8 * len(unique_a_only)))
+ num_train_b = int(np.round(0.8 * len(unique_b_only)))
+
+ p = RandomState(1234)
+ idx_a_train = p.choice(len(unique_a_only), num_train_a, replace=False)
+ idx_b_train = p.choice(len(unique_b_only), num_train_b, replace=False)
+
+ identifiers_a_train = [unique_a_only[i] for i in idx_a_train]
+ identifiers_b_train = [unique_b_only[i] for i in idx_b_train]
+
+ identifiers_a_val = [i for i in unique_a_only if i not in identifiers_a_train]
+ identifiers_b_val = [i for i in unique_b_only if i not in identifiers_b_train]
+
+ # fold 5 will be train on a and eval on val sets of a and b
+ splits.append({'train': [i for i in data_keys if i.split("_")[0] in identifiers_a_train],
+ 'val': [i for i in data_keys if i.split("_")[0] in identifiers_a_val] + [i for i in data_keys if
+ i.split("_")[
+ 0] in identifiers_b_val]})
+
+ # fold 6 will be train on b and eval on val sets of a and b
+ splits.append({'train': [i for i in data_keys if i.split("_")[0] in identifiers_b_train],
+ 'val': [i for i in data_keys if i.split("_")[0] in identifiers_a_val] + [i for i in data_keys if
+ i.split("_")[
+ 0] in identifiers_b_val]})
+
+ # fold 7 train on both, eval on both
+ splits.append({'train': [i for i in data_keys if i.split("_")[0] in identifiers_b_train] + [i for i in data_keys if i.split("_")[0] in identifiers_a_train],
+ 'val': [i for i in data_keys if i.split("_")[0] in identifiers_a_val] + [i for i in data_keys if
+ i.split("_")[
+ 0] in identifiers_b_val]})
+ save_pickle(splits, existing_splits)
+
+
+if __name__ == "__main__":
+    # this script splits the 4d data from the M&Ms dataset into 3d images, for both raw images and gt annotations.
+    # after running it you will be able to start training on the M&Ms data.
+    # use this script as inspiration in case data other than M&Ms is used for training.
+ #
+ # check also the comments at the END of the script for instructions on how to run the actual training after this
+ # script
+ #
+
+    # define a task ID for your experiment (I have chosen 114)
+    task_name = "Task114_heart_mnms"
+ # this is where the downloaded data from the M&Ms challenge shall be placed
+ raw_data_dir = "/media/full/tera2/data"
+ # set path to official ***M&Ms Dataset Information.xlsx*** file
+ df_path = "/media/full/tera2/data/challenges/mms/Training-corrected_original/M&Ms Dataset Information.xlsx"
+ # don't make changes here
+ folder_imagesTr = "imagesTr"
+ train_dir = os.path.join(raw_data_dir, task_name, folder_imagesTr)
+
+    # this is where your split files WITH annotations will be stored. Don't make changes here, otherwise nnU-Net
+    # might have problems finding the training data later during the training process
+ out_dir = os.path.join(os.environ.get('nnUNet_raw_data_base'), 'nnUNet_raw_data', task_name)
+
+ files_raw, files_gt = get_mnms_data(data_root=train_dir)
+
+ split_path_raw_all_ts = os.path.join(raw_data_dir, task_name, "splitted_all_timesteps", folder_imagesTr,
+ "split_raw_images")
+ split_path_gt_all_ts = os.path.join(raw_data_dir, task_name, "splitted_all_timesteps", folder_imagesTr,
+ "split_annotation")
+ maybe_mkdir_p(split_path_raw_all_ts)
+ maybe_mkdir_p(split_path_gt_all_ts)
+
+    # for fast splitting of many patients you can use the following lines; however, keep in mind that they
+    # cause problems for some users. If problems occur, use the for loops below instead.
+    # print("splitting raw 4d images into 3d images")
+    # split_4d_for_all_pat(files_raw, split_path_raw_all_ts)
+    # print("splitting ground truth 4d into 3d files")
+    # split_4d_for_all_pat(files_gt, split_path_gt_all_ts)
+
+ print("splitting raw 4d images into 3d images")
+ for f in files_raw:
+ print("splitting {}".format(f))
+ split_4d_nifti(f, split_path_raw_all_ts)
+ print("splitting ground truth 4d into 3d files")
+    for gt in files_gt:
+        print("splitting {}".format(gt))
+        split_4d_nifti(gt, split_path_gt_all_ts)
+
+ print("prepared data will be saved at: {}".format(out_dir))
+ maybe_mkdir_p(join(out_dir, "imagesTr"))
+ maybe_mkdir_p(join(out_dir, "labelsTr"))
+
+ imagesTr_path = os.path.join(out_dir, "imagesTr")
+ labelsTr_path = os.path.join(out_dir, "labelsTr")
+    # only a small fraction of all timesteps in the cardiac cycle possess gt annotations. These timesteps will now
+    # be selected
+ select_annotated_frames_mms(split_path_raw_all_ts, imagesTr_path, add_zeros=True, is_gt=False, df_path=df_path)
+ select_annotated_frames_mms(split_path_gt_all_ts, labelsTr_path, add_zeros=False, is_gt=True, df_path=df_path)
+
+ labelsTr = subfiles(labelsTr_path)
+
+ # create a json file that will be needed by nnUNet to initiate the preprocessing process
+ json_dict = OrderedDict()
+ json_dict['name'] = "M&Ms"
+ json_dict['description'] = "short axis cardiac cine MRI segmentation"
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "Campello, Victor M et al. “Multi-Centre, Multi-Vendor and Multi-Disease Cardiac " \
+ "Segmentation: The M&Ms Challenge.” IEEE transactions on " \
+ "medical imaging vol. 40,12 (2021): 3543-3554. doi:10.1109/TMI.2021.3090082"
+ json_dict['licence'] = "see M&Ms challenge"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "MRI",
+ }
+ # labels differ for ACDC challenge
+ json_dict['labels'] = {
+ "0": "background",
+ "1": "LVBP",
+ "2": "LVM",
+ "3": "RV"
+ }
+ json_dict['numTraining'] = len(labelsTr)
+ json_dict['numTest'] = 0
+ json_dict['training'] = [{'image': "./imagesTr/%s" % i.split("/")[-1],
+ "label": "./labelsTr/%s" % i.split("/")[-1]} for i in labelsTr]
+ json_dict['test'] = []
+
+ save_json(json_dict, os.path.join(out_dir, "dataset.json"))
+
+ #
+ # now the data is ready to be preprocessed by the nnUNet
+    # the following steps are only needed if you want to reproduce the exact results from the M&Ms challenge
+ #
+
+
+ # then preprocess data and plan training.
+ # run in terminal
+ # nnUNet_plan_and_preprocess -t 114 --verify_dataset_integrity # for 2d
+ # nnUNet_plan_and_preprocess -t 114 --verify_dataset_integrity -pl3d ExperimentPlannerTargetSpacingForAnisoAxis # for 3d
+
+    # start training and stop it immediately to get a splits_final.pkl file
+ # nnUNet_train 2d nnUNetTrainerV2_MMS 114 0
+
+ #
+ # then create custom splits as used for the final M&Ms submission
+ #
+
+ # in this file comment everything except for the following line
+ # create_custom_splits_for_experiments(out_dir)
+
+ # then start training with
+ #
+ # nnUNet_train 3d_fullres nnUNetTrainerV2_MMS Task114_heart_mnms -p nnUNetPlanstargetSpacingForAnisoAxis 0 # for 3d and fold 0
+ # and
+ # nnUNet_train 2d nnUNetTrainerV2_MMS Task114_heart_mnms 0 # for 2d and fold 0
+
+
diff --git a/nnunet/dataset_conversion/Task115_COVIDSegChallenge.py b/nnunet/dataset_conversion/Task115_COVIDSegChallenge.py
new file mode 100644
index 0000000000000000000000000000000000000000..36ab390f475bd6d1fe713eb3743ce1ca9bd87654
--- /dev/null
+++ b/nnunet/dataset_conversion/Task115_COVIDSegChallenge.py
@@ -0,0 +1,344 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import shutil
+import subprocess
+
+import SimpleITK as sitk
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+
+from nnunet.dataset_conversion.utils import generate_dataset_json
+from nnunet.paths import nnUNet_raw_data
+from nnunet.paths import preprocessing_output_dir
+from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name
+
+
+def increase_batch_size(plans_file: str, save_as: str, bs_factor: int):
+ a = load_pickle(plans_file)
+ stages = list(a['plans_per_stage'].keys())
+ for s in stages:
+ a['plans_per_stage'][s]['batch_size'] *= bs_factor
+ save_pickle(a, save_as)
+
+
+def prepare_submission(folder_in, folder_out):
+ nii = subfiles(folder_in, suffix='.gz', join=False)
+ maybe_mkdir_p(folder_out)
+ for n in nii:
+ i = n.split('-')[-1][:-10]
+ shutil.copy(join(folder_in, n), join(folder_out, i + '.nii.gz'))
+
+
+def get_ids_from_folder(folder):
+ cts = subfiles(folder, suffix='_ct.nii.gz', join=False)
+ ids = []
+ for c in cts:
+ ids.append(c.split('-')[-1][:-10])
+ return ids
+
+
+def postprocess_submission(folder_ct, folder_pred, folder_postprocessed, bbox_distance_to_seg_in_cm=7.5):
+ """
+ segment with lung mask, get bbox from that, use bbox to remove predictions in background
+
+ WE EXPERIMENTED WITH THAT ON THE VALIDATION SET AND FOUND THAT IT DOESN'T DO ANYTHING. NOT USED FOR TEST SET
+ """
+ # pip install git+https://github.com/JoHof/lungmask
+ cts = subfiles(folder_ct, suffix='_ct.nii.gz', join=False)
+ output_files = [i[:-10] + '_lungmask.nii.gz' for i in cts]
+
+ # run lungmask on everything
+ for i, o in zip(cts, output_files):
+ if not isfile(join(folder_ct, o)):
+ subprocess.call(['lungmask', join(folder_ct, i), join(folder_ct, o), '--modelname', 'R231CovidWeb'])
+
+    maybe_mkdir_p(folder_postprocessed)
+
+ ids = get_ids_from_folder(folder_ct)
+ for i in ids:
+ # find lungmask
+ lungmask_file = join(folder_ct, 'volume-covid19-A-' + i + '_lungmask.nii.gz')
+        if not isfile(lungmask_file):
+            raise RuntimeError('missing lung mask for case ' + i)
+        seg_file = join(folder_pred, 'volume-covid19-A-' + i + '_ct.nii.gz')
+        if not isfile(seg_file):
+            raise RuntimeError('missing segmentation for case ' + i)
+
+ lung_mask = sitk.GetArrayFromImage(sitk.ReadImage(lungmask_file))
+ seg_itk = sitk.ReadImage(seg_file)
+ seg = sitk.GetArrayFromImage(seg_itk)
+
+ where = np.argwhere(lung_mask != 0)
+ bbox = [
+ [min(where[:, 0]), max(where[:, 0])],
+ [min(where[:, 1]), max(where[:, 1])],
+ [min(where[:, 2]), max(where[:, 2])],
+ ]
+
+ spacing = np.array(seg_itk.GetSpacing())[::-1]
+ # print(bbox)
+ for dim in range(3):
+ sp = spacing[dim]
+ voxels_extend = max(int(np.ceil(bbox_distance_to_seg_in_cm / sp)), 1)
+ bbox[dim][0] = max(0, bbox[dim][0] - voxels_extend)
+ bbox[dim][1] = min(seg.shape[dim], bbox[dim][1] + voxels_extend)
+ # print(bbox)
+
+ seg_old = np.copy(seg)
+ seg[0:bbox[0][0], :, :] = 0
+ seg[bbox[0][1]:, :, :] = 0
+ seg[:, 0:bbox[1][0], :] = 0
+ seg[:, bbox[1][1]:, :] = 0
+ seg[:, :, 0:bbox[2][0]] = 0
+ seg[:, :, bbox[2][1]:] = 0
+ if np.any(seg_old != seg):
+ print('changed seg', i)
+ argwhere = np.argwhere(seg != seg_old)
+ print(argwhere[np.random.choice(len(argwhere), 10)])
+
+ seg_corr = sitk.GetImageFromArray(seg)
+ seg_corr.CopyInformation(seg_itk)
+ sitk.WriteImage(seg_corr, join(folder_postprocessed, 'volume-covid19-A-' + i + '_ct.nii.gz'))
+
+
+def manually_set_configurations():
+ """
+ ALSO NOT USED!
+ :return:
+ """
+ task115_dir = join(preprocessing_output_dir, convert_id_to_task_name(115))
+
+ ## larger patch size
+
+ # task115 3d_fullres default is:
+ """
+ {'batch_size': 2,
+ 'num_pool_per_axis': [2, 6, 6],
+ 'patch_size': array([ 28, 256, 256]),
+ 'median_patient_size_in_voxels': array([ 62, 512, 512]),
+ 'current_spacing': array([5. , 0.74199998, 0.74199998]),
+ 'original_spacing': array([5. , 0.74199998, 0.74199998]),
+ 'do_dummy_2D_data_aug': True,
+ 'pool_op_kernel_sizes': [[1, 2, 2], [1, 2, 2], [2, 2, 2], [2, 2, 2], [1, 2, 2], [1, 2, 2]],
+ 'conv_kernel_sizes': [[1, 3, 3], [1, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3], [3, 3, 3]]}
+ """
+ plans = load_pickle(join(task115_dir, 'nnUNetPlansv2.1_plans_3D.pkl'))
+ fullres_stage = plans['plans_per_stage'][1]
+ fullres_stage['patch_size'] = np.array([ 64, 320, 320])
+ fullres_stage['num_pool_per_axis'] = [4, 6, 6]
+ fullres_stage['pool_op_kernel_sizes'] = [[1, 2, 2],
+ [1, 2, 2],
+ [2, 2, 2],
+ [2, 2, 2],
+ [2, 2, 2],
+ [2, 2, 2]]
+ fullres_stage['conv_kernel_sizes'] = [[1, 3, 3],
+ [1, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3]]
+
+ save_pickle(plans, join(task115_dir, 'nnUNetPlansv2.1_custom_plans_3D.pkl'))
+
+ ## larger batch size
+ # (default for all 3d trainings is batch size 2)
+ increase_batch_size(join(task115_dir, 'nnUNetPlansv2.1_plans_3D.pkl'), join(task115_dir, 'nnUNetPlansv2.1_bs3x_plans_3D.pkl'), 3)
+ increase_batch_size(join(task115_dir, 'nnUNetPlansv2.1_plans_3D.pkl'), join(task115_dir, 'nnUNetPlansv2.1_bs5x_plans_3D.pkl'), 5)
+
+ # residual unet
+ """
+ default is:
+ Out[7]:
+ {'batch_size': 2,
+ 'num_pool_per_axis': [2, 6, 5],
+ 'patch_size': array([ 28, 256, 224]),
+ 'median_patient_size_in_voxels': array([ 62, 512, 512]),
+ 'current_spacing': array([5. , 0.74199998, 0.74199998]),
+ 'original_spacing': array([5. , 0.74199998, 0.74199998]),
+ 'do_dummy_2D_data_aug': True,
+ 'pool_op_kernel_sizes': [[1, 1, 1],
+ [1, 2, 2],
+ [1, 2, 2],
+ [2, 2, 2],
+ [2, 2, 2],
+ [1, 2, 2],
+ [1, 2, 1]],
+ 'conv_kernel_sizes': [[1, 3, 3],
+ [1, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3]],
+ 'num_blocks_encoder': (1, 2, 3, 4, 4, 4, 4),
+ 'num_blocks_decoder': (1, 1, 1, 1, 1, 1)}
+ """
+ plans = load_pickle(join(task115_dir, 'nnUNetPlans_FabiansResUNet_v2.1_plans_3D.pkl'))
+ fullres_stage = plans['plans_per_stage'][1]
+ fullres_stage['patch_size'] = np.array([ 56, 256, 256])
+ fullres_stage['num_pool_per_axis'] = [3, 6, 6]
+ fullres_stage['pool_op_kernel_sizes'] = [[1, 1, 1],
+ [1, 2, 2],
+ [1, 2, 2],
+ [2, 2, 2],
+ [2, 2, 2],
+ [2, 2, 2],
+ [1, 2, 2]]
+ fullres_stage['conv_kernel_sizes'] = [[1, 3, 3],
+ [1, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3]]
+ save_pickle(plans, join(task115_dir, 'nnUNetPlans_FabiansResUNet_v2.1_custom_plans_3D.pkl'))
+
+
+def check_same(img1: str, img2: str):
+ """
+ checking initial vs corrected dataset
+ :param img1:
+ :param img2:
+ :return:
+ """
+ img1 = sitk.GetArrayFromImage(sitk.ReadImage(img1))
+ img2 = sitk.GetArrayFromImage(sitk.ReadImage(img2))
+ if not np.all([i==j for i, j in zip(img1.shape, img2.shape)]):
+ print('shape')
+ return False
+ else:
+ same = np.all(img1==img2)
+ if same: return True
+ else:
+ diffs = np.argwhere(img1!=img2)
+ print('content in', diffs.shape[0], 'voxels')
+ print('random disagreements:')
+ print(diffs[np.random.choice(len(diffs), min(3, diffs.shape[0]), replace=False)])
+ return False
+
+
+def check_dataset_same(dataset_old='/home/fabian/Downloads/COVID-19-20/Train',
+ dataset_new='/home/fabian/data/COVID-19-20_officialCorrected/COVID-19-20_v2/Train'):
+ """
+ :param dataset_old:
+ :param dataset_new:
+ :return:
+ """
+ cases = [i[:-10] for i in subfiles(dataset_new, suffix='_ct.nii.gz', join=False)]
+ for c in cases:
+ data_file = join(dataset_old, c + '_ct_corrDouble.nii.gz')
+ corrected_double = False
+ if not isfile(data_file):
+ data_file = join(dataset_old, c+'_ct.nii.gz')
+ else:
+ corrected_double = True
+ data_file_new = join(dataset_new, c+'_ct.nii.gz')
+
+ same = check_same(data_file, data_file_new)
+ if not same: print('data differs in case', c, '\n')
+
+ seg_file = join(dataset_old, c + '_seg_corrDouble_corrected.nii.gz')
+ if not isfile(seg_file):
+ seg_file = join(dataset_old, c + '_seg_corrected_auto.nii.gz')
+ if isfile(seg_file):
+                assert not corrected_double
+ else:
+ seg_file = join(dataset_old, c + '_seg_corrected.nii.gz')
+ if isfile(seg_file):
+                    assert not corrected_double
+ else:
+ seg_file = join(dataset_old, c + '_seg_corrDouble.nii.gz')
+ if isfile(seg_file):
+                        assert not corrected_double
+ else:
+ seg_file = join(dataset_old, c + '_seg.nii.gz')
+ seg_file_new = join(dataset_new, c + '_seg.nii.gz')
+ same = check_same(seg_file, seg_file_new)
+ if not same: print('seg differs in case', c, '\n')
+
+
+if __name__ == '__main__':
+ # this is the folder containing the data as downloaded from https://covid-segmentation.grand-challenge.org/COVID-19-20/
+ # (zip file was decompressed!)
+ downloaded_data_dir = '/home/fabian/data/COVID-19-20_officialCorrected/COVID-19-20_v2/'
+
+ task_name = "Task115_COVIDSegChallenge"
+
+ target_base = join(nnUNet_raw_data, task_name)
+
+ target_imagesTr = join(target_base, "imagesTr")
+ target_imagesVal = join(target_base, "imagesVal")
+ target_labelsTr = join(target_base, "labelsTr")
+
+ maybe_mkdir_p(target_imagesTr)
+ maybe_mkdir_p(target_imagesVal)
+ maybe_mkdir_p(target_labelsTr)
+
+ train_orig = join(downloaded_data_dir, "Train")
+
+ # convert training set
+ cases = [i[:-10] for i in subfiles(train_orig, suffix='_ct.nii.gz', join=False)]
+ for c in cases:
+ data_file = join(train_orig, c+'_ct.nii.gz')
+
+ # before there was the official corrected dataset we did some corrections of our own. These corrections were
+ # dropped when the official dataset was revised.
+ seg_file = join(train_orig, c + '_seg_corrected.nii.gz')
+ if not isfile(seg_file):
+ seg_file = join(train_orig, c + '_seg.nii.gz')
+
+ shutil.copy(data_file, join(target_imagesTr, c + "_0000.nii.gz"))
+ shutil.copy(seg_file, join(target_labelsTr, c + '.nii.gz'))
+
+ val_orig = join(downloaded_data_dir, "Validation")
+ cases = [i[:-10] for i in subfiles(val_orig, suffix='_ct.nii.gz', join=False)]
+ for c in cases:
+ data_file = join(val_orig, c + '_ct.nii.gz')
+
+ shutil.copy(data_file, join(target_imagesVal, c + "_0000.nii.gz"))
+
+ generate_dataset_json(
+ join(target_base, 'dataset.json'),
+ target_imagesTr,
+ None,
+ ("CT", ),
+ {0: 'background', 1: 'covid'},
+ task_name,
+ dataset_reference='https://covid-segmentation.grand-challenge.org/COVID-19-20/'
+ )
+
+ # performance summary (train set 5-fold cross-validation)
+
+ # baselines
+ # 3d_fullres nnUNetTrainerV2__nnUNetPlans_v2.1 0.7441
+ # 3d_lowres nnUNetTrainerV2__nnUNetPlans_v2.1 0.745
+
+ # models used for test set prediction
+ # 3d_fullres nnUNetTrainerV2_ResencUNet_DA3__nnUNetPlans_FabiansResUNet_v2.1 0.7543
+ # 3d_fullres nnUNetTrainerV2_ResencUNet__nnUNetPlans_FabiansResUNet_v2.1 0.7527
+ # 3d_lowres nnUNetTrainerV2_ResencUNet_DA3_BN__nnUNetPlans_FabiansResUNet_v2.1 0.7513
+ # 3d_fullres nnUNetTrainerV2_DA3_BN__nnUNetPlans_v2.1 0.7498
+ # 3d_fullres nnUNetTrainerV2_DA3__nnUNetPlans_v2.1 0.7532
+
+ # Test set prediction
+ # nnUNet_predict -i COVID-19-20_TestSet -o covid_testset_predictions/3d_fullres/nnUNetTrainerV2_ResencUNet_DA3__nnUNetPlans_FabiansResUNet_v2.1 -tr nnUNetTrainerV2_ResencUNet_DA3 -p nnUNetPlans_FabiansResUNet_v2.1 -m 3d_fullres -f 0 1 2 3 4 5 6 7 8 9 -t 115 -z
+ # nnUNet_predict -i COVID-19-20_TestSet -o covid_testset_predictions/3d_fullres/nnUNetTrainerV2_ResencUNet__nnUNetPlans_FabiansResUNet_v2.1 -tr nnUNetTrainerV2_ResencUNet -p nnUNetPlans_FabiansResUNet_v2.1 -m 3d_fullres -f 0 1 2 3 4 5 6 7 8 9 -t 115 -z
+ # nnUNet_predict -i COVID-19-20_TestSet -o covid_testset_predictions/3d_lowres/nnUNetTrainerV2_ResencUNet_DA3_BN__nnUNetPlans_FabiansResUNet_v2.1 -tr nnUNetTrainerV2_ResencUNet_DA3_BN -p nnUNetPlans_FabiansResUNet_v2.1 -m 3d_lowres -f 0 1 2 3 4 5 6 7 8 9 -t 115 -z
+ # nnUNet_predict -i COVID-19-20_TestSet -o covid_testset_predictions/3d_fullres/nnUNetTrainerV2_DA3_BN__nnUNetPlans_v2.1 -tr nnUNetTrainerV2_DA3_BN -m 3d_fullres -f 0 1 2 3 4 5 6 7 8 9 -t 115 -z
+ # nnUNet_predict -i COVID-19-20_TestSet -o covid_testset_predictions/3d_fullres/nnUNetTrainerV2_DA3__nnUNetPlans_v2.1 -tr nnUNetTrainerV2_DA3 -m 3d_fullres -f 0 1 2 3 4 5 6 7 8 9 -t 115 -z
+
+ # nnUNet_ensemble -f 3d_lowres/nnUNetTrainerV2_ResencUNet_DA3_BN__nnUNetPlans_FabiansResUNet_v2.1/ 3d_fullres/nnUNetTrainerV2_ResencUNet__nnUNetPlans_FabiansResUNet_v2.1/ 3d_fullres/nnUNetTrainerV2_ResencUNet_DA3__nnUNetPlans_FabiansResUNet_v2.1/ 3d_fullres/nnUNetTrainerV2_DA3_BN__nnUNetPlans_v2.1/ 3d_fullres/nnUNetTrainerV2_DA3__nnUNetPlans_v2.1/ -o ensembled
diff --git a/nnunet/dataset_conversion/Task120_Massachusetts_RoadSegm.py b/nnunet/dataset_conversion/Task120_Massachusetts_RoadSegm.py
new file mode 100644
index 0000000000000000000000000000000000000000..633a476c7d4d3a58a95081210732cc22b53442d1
--- /dev/null
+++ b/nnunet/dataset_conversion/Task120_Massachusetts_RoadSegm.py
@@ -0,0 +1,103 @@
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.dataset_conversion.utils import generate_dataset_json
+from nnunet.paths import nnUNet_raw_data, preprocessing_output_dir
+from nnunet.utilities.file_conversions import convert_2d_image_to_nifti
+
+if __name__ == '__main__':
+ """
+ nnU-Net was originally built for 3D images. It is also strongest when applied to 3D segmentation problems because a
+ large proportion of its design choices were built with 3D in mind. Also note that many 2D segmentation problems,
+ especially in the non-biomedical domain, may benefit from pretrained network architectures which nnU-Net does not
+ support.
+    Still, there is certainly a need for an out-of-the-box segmentation solution for 2D segmentation problems, and
+    nnU-Net can perform extremely well on 2D tasks as well! We have, for example, won a 2D task in the cell
+    tracking challenge with nnU-Net (see our Nature Methods paper) and we have also successfully applied nnU-Net to
+    histopathological segmentation problems.
+    Working with 2D data in nnU-Net requires a small workaround in the creation of the dataset. Essentially, all
+    images must be converted to pseudo-3D images (an image with shape (X, Y) needs to be converted to an image with
+    shape (1, X, Y)). The resulting image must be saved in nifti format. It is important to set the spacing of the
+    first axis (the one with shape 1) to a value larger than that of the other axes. If you are working with niftis
+    anyway, doing this should be easy. This example is intended to demonstrate how nnU-Net can be used with
+    'regular' 2D images. We selected the Massachusetts road segmentation dataset for this because it can be obtained
+    easily and comes with a good amount of training cases while still not being too large to handle.
+ """
+
+ # download dataset from https://www.kaggle.com/insaff/massachusetts-roads-dataset
+ # extract the zip file, then set the following path according to your system:
+ base = '/media/fabian/data/road_segmentation_ideal'
+ # this folder should have the training and testing subfolders
+
+ # now start the conversion to nnU-Net:
+ task_name = 'Task120_MassRoadsSeg'
+ target_base = join(nnUNet_raw_data, task_name)
+ target_imagesTr = join(target_base, "imagesTr")
+ target_imagesTs = join(target_base, "imagesTs")
+ target_labelsTs = join(target_base, "labelsTs")
+ target_labelsTr = join(target_base, "labelsTr")
+
+ maybe_mkdir_p(target_imagesTr)
+ maybe_mkdir_p(target_labelsTs)
+ maybe_mkdir_p(target_imagesTs)
+ maybe_mkdir_p(target_labelsTr)
+
+ # convert the training examples. Not all training images have labels, so we just take the cases for which there are
+ # labels
+ labels_dir_tr = join(base, 'training', 'output')
+ images_dir_tr = join(base, 'training', 'input')
+ training_cases = subfiles(labels_dir_tr, suffix='.png', join=False)
+ for t in training_cases:
+ unique_name = t[:-4] # just the filename with the extension cropped away, so img-2.png becomes img-2 as unique_name
+ input_segmentation_file = join(labels_dir_tr, t)
+ input_image_file = join(images_dir_tr, t)
+
+ output_image_file = join(target_imagesTr, unique_name) # do not specify a file ending! This will be done for you
+ output_seg_file = join(target_labelsTr, unique_name) # do not specify a file ending! This will be done for you
+
+ # this utility will convert 2d images that can be read by skimage.io.imread to nifti. You don't need to do anything.
+ # if this throws an error for your images, please just look at the code for this function and adapt it to your needs
+ convert_2d_image_to_nifti(input_image_file, output_image_file, is_seg=False)
+
+ # the labels are stored as 0: background, 255: road. We need to convert the 255 to 1 because nnU-Net expects
+ # the labels to be consecutive integers. This can be achieved with setting a transform
+ convert_2d_image_to_nifti(input_segmentation_file, output_seg_file, is_seg=True,
+ transform=lambda x: (x == 255).astype(int))
+
+ # now do the same for the test set
+ labels_dir_ts = join(base, 'testing', 'output')
+ images_dir_ts = join(base, 'testing', 'input')
+ testing_cases = subfiles(labels_dir_ts, suffix='.png', join=False)
+ for ts in testing_cases:
+ unique_name = ts[:-4]
+ input_segmentation_file = join(labels_dir_ts, ts)
+ input_image_file = join(images_dir_ts, ts)
+
+ output_image_file = join(target_imagesTs, unique_name)
+ output_seg_file = join(target_labelsTs, unique_name)
+
+ convert_2d_image_to_nifti(input_image_file, output_image_file, is_seg=False)
+ convert_2d_image_to_nifti(input_segmentation_file, output_seg_file, is_seg=True,
+ transform=lambda x: (x == 255).astype(int))
+
+ # finally we can call the utility for generating a dataset.json
+ generate_dataset_json(join(target_base, 'dataset.json'), target_imagesTr, target_imagesTs, ('Red', 'Green', 'Blue'),
+ labels={0: 'background', 1: 'street'}, dataset_name=task_name, license='hands off!')
+
+    """
+    Once this is completed, you can use the dataset like any other nnU-Net dataset. Note that since this is a 2D
+    dataset there is no need to run preprocessing for 3D U-Nets. You should therefore run the
+    `nnUNet_plan_and_preprocess` command like this:
+
+    > nnUNet_plan_and_preprocess -t 120 -pl3d None
+
+    Once that is completed, you can run the training as follows:
+    > nnUNet_train 2d nnUNetTrainerV2 120 FOLD
+
+    (where FOLD is 0, 1, 2, 3 or 4 - 5-fold cross-validation)
+
+    There is no need to run nnUNet_find_best_configuration because there is only one model to choose from.
+    Note that without running nnUNet_find_best_configuration, nnU-Net will not have determined a postprocessing
+    for the whole cross-validation. Spoiler: it will determine not to run postprocessing anyway. If you are using
+    a different 2D dataset, you can make nnU-Net determine the postprocessing by using the
+    `nnUNet_determine_postprocessing` command.
+    """
diff --git a/nnunet/dataset_conversion/Task135_KiTS2021.py b/nnunet/dataset_conversion/Task135_KiTS2021.py
new file mode 100644
index 0000000000000000000000000000000000000000..eee6672f79d50068b12b7126e4b414c4ae5b4490
--- /dev/null
+++ b/nnunet/dataset_conversion/Task135_KiTS2021.py
@@ -0,0 +1,49 @@
+from batchgenerators.utilities.file_and_folder_operations import *
+import shutil
+
+from nnunet.paths import nnUNet_raw_data
+from nnunet.dataset_conversion.utils import generate_dataset_json
+
+if __name__ == '__main__':
+ # this is the data folder from the kits21 github repository, see https://github.com/neheller/kits21
+ kits_data_dir = '/home/fabian/git_repos/kits21/kits21/data'
+
+ # This script uses the majority voted segmentation as ground truth
+ kits_segmentation_filename = 'aggregated_MAJ_seg.nii.gz'
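+    # (the kits21 repository also generates other aggregations with different voting rules; if present in your
+    # checkout, substituting the filename here switches the ground truth used)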
+
+    # Arbitrary task id. This is just to ensure each dataset has a unique number. Set this to whatever ([0-999]) you
+    # want
+ task_id = 135
+ task_name = "KiTS2021"
+
+ foldername = "Task%03.0d_%s" % (task_id, task_name)
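+    # with task_id=135 and task_name='KiTS2021' this evaluates to 'Task135_KiTS2021'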
+
+ # setting up nnU-Net folders
+ out_base = join(nnUNet_raw_data, foldername)
+ imagestr = join(out_base, "imagesTr")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(labelstr)
+
+ case_ids = subdirs(kits_data_dir, prefix='case_', join=False)
+ for c in case_ids:
+ if isfile(join(kits_data_dir, c, kits_segmentation_filename)):
+ shutil.copy(join(kits_data_dir, c, kits_segmentation_filename), join(labelstr, c + '.nii.gz'))
+ shutil.copy(join(kits_data_dir, c, 'imaging.nii.gz'), join(imagestr, c + '_0000.nii.gz'))
+
+ generate_dataset_json(join(out_base, 'dataset.json'),
+ imagestr,
+ None,
+ ('CT',),
+ {
+ 0: 'background',
+ 1: "kidney",
+ 2: "tumor",
+ 3: "cyst",
+ },
+ task_name,
+ license='see https://kits21.kits-challenge.org/participate#download-block',
+ dataset_description='see https://kits21.kits-challenge.org/',
+ dataset_reference='https://www.sciencedirect.com/science/article/abs/pii/S1361841520301857, '
+ 'https://kits21.kits-challenge.org/',
+ dataset_release='0')
diff --git a/nnunet/dataset_conversion/Task154_RibFrac_multi_label.py b/nnunet/dataset_conversion/Task154_RibFrac_multi_label.py
new file mode 100755
index 0000000000000000000000000000000000000000..c77e13e2ba9b5d27c8240b83107fd4516bceb14f
--- /dev/null
+++ b/nnunet/dataset_conversion/Task154_RibFrac_multi_label.py
@@ -0,0 +1,172 @@
+import SimpleITK as sitk
+from natsort import natsorted
+import numpy as np
+from pathlib import Path
+import pandas as pd
+from collections import defaultdict
+from shutil import copyfile
+import os
+from os.path import join
+from tqdm import tqdm
+import gc
+import multiprocessing as mp
+from nnunet.dataset_conversion.utils import generate_dataset_json
+from functools import partial
+
+
+def preprocess_dataset(dataset_load_path, dataset_save_path, pool):
+ train_image_load_path = join(dataset_load_path, "imagesTr")
+ train_mask_load_path = join(dataset_load_path, "labelsTr")
+ test_image_load_path = join(dataset_load_path, "imagesTs")
+
+ ribfrac_train_info_1_path = join(dataset_load_path, "ribfrac-train-info-1.csv")
+ ribfrac_train_info_2_path = join(dataset_load_path, "ribfrac-train-info-2.csv")
+ ribfrac_val_info_path = join(dataset_load_path, "ribfrac-val-info.csv")
+
+ train_image_save_path = join(dataset_save_path, "imagesTr")
+ train_mask_save_path = join(dataset_save_path, "labelsTr")
+ test_image_save_path = join(dataset_save_path, "imagesTs")
+ Path(train_image_save_path).mkdir(parents=True, exist_ok=True)
+ Path(train_mask_save_path).mkdir(parents=True, exist_ok=True)
+ Path(test_image_save_path).mkdir(parents=True, exist_ok=True)
+
+ meta_data = preprocess_csv(ribfrac_train_info_1_path, ribfrac_train_info_2_path, ribfrac_val_info_path)
+ preprocess_train(train_image_load_path, train_mask_load_path, meta_data, dataset_save_path, pool)
+ preprocess_test(test_image_load_path, dataset_save_path)
+
+
+def preprocess_csv(ribfrac_train_info_1_path, ribfrac_train_info_2_path, ribfrac_val_info_path):
+ print("Processing csv...")
+ meta_data = defaultdict(list)
+ for csv_path in [ribfrac_train_info_1_path, ribfrac_train_info_2_path, ribfrac_val_info_path]:
+ df = pd.read_csv(csv_path)
+ for index, row in df.iterrows():
+ name = row["public_id"]
+ instance = row["label_id"]
+ class_label = row["label_code"]
+ meta_data[name].append({"instance": instance, "class_label": class_label})
+ print("Finished csv processing.")
+ return meta_data
+
+
+def preprocess_train(image_path, mask_path, meta_data, save_path, pool):
+ print("Processing train data...")
+ pool.map(partial(preprocess_train_single, image_path=image_path, mask_path=mask_path, meta_data=meta_data, save_path=save_path), meta_data.keys())
+ print("Finished processing train data.")
+
+
+def preprocess_train_single(name, image_path, mask_path, meta_data, save_path):
+ id = int(name[7:])
+ image, _, _, _ = load_image(join(image_path, name + "-image.nii.gz"), return_meta=True, is_seg=False)
+ instance_seg_mask, spacing, _, _ = load_image(join(mask_path, name + "-label.nii.gz"), return_meta=True, is_seg=True)
+ semantic_seg_mask = np.zeros_like(instance_seg_mask, dtype=int)
+ for entry in meta_data[name]:
+ semantic_seg_mask[instance_seg_mask == entry["instance"]] = entry["class_label"]
+ semantic_seg_mask[semantic_seg_mask == -1] = 5 # Set ignore label to 5
+ save_image(join(save_path, "imagesTr/RibFrac_" + str(id).zfill(4) + "_0000.nii.gz"), image, spacing=spacing, is_seg=False)
+ save_image(join(save_path, "labelsTr/RibFrac_" + str(id).zfill(4) + ".nii.gz"), semantic_seg_mask, spacing=spacing, is_seg=True)
+
+
+def preprocess_test(load_test_image_dir, save_path):
+ print("Processing test data...")
+ filenames = load_filenames(load_test_image_dir)
+ for filename in tqdm(filenames):
+ id = int(os.path.basename(filename)[8:-13])
+ copyfile(filename, join(save_path, "imagesTs/RibFrac_" + str(id).zfill(4) + "_0000.nii.gz"))
+ print("Finished processing test data.")
+
+
+def load_filenames(img_dir, extensions=None):
+ _img_dir = fix_path(img_dir)
+ img_filenames = []
+
+ for file in os.listdir(_img_dir):
+ if extensions is None or file.endswith(extensions):
+ img_filenames.append(_img_dir + file)
+ img_filenames = np.asarray(img_filenames)
+ img_filenames = natsorted(img_filenames)
+
+ return img_filenames
+
+
+def fix_path(path):
+ if path[-1] != "/":
+ path += "/"
+ return path
+
+
+def load_image(filepath, return_meta=False, is_seg=False):
+ image = sitk.ReadImage(filepath)
+ image_np = sitk.GetArrayFromImage(image)
+
+ if is_seg:
+ image_np = np.rint(image_np)
+ image_np = image_np.astype(np.int8) # In special cases segmentations can contain negative labels, so no np.uint8
+
+ if not return_meta:
+ return image_np
+ else:
+ spacing = image.GetSpacing()
+ keys = image.GetMetaDataKeys()
+        header = {key: image.GetMetaData(key) for key in keys}
+ affine = None # How do I get the affine transform with SimpleITK? With NiBabel it is just image.affine
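+        # (side note: SimpleITK does not expose a 4x4 affine directly. If one is needed, it can be assembled
+        # from direction, spacing and origin, e.g.:
+        #     affine = np.eye(4)
+        #     affine[:3, :3] = np.reshape(image.GetDirection(), (3, 3)) * np.asarray(image.GetSpacing())
+        #     affine[:3, 3] = image.GetOrigin()
+        # keep in mind this is in ITK's LPS convention, not the RAS convention used by nibabel)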
+ return image_np, spacing, affine, header
+
+
+def save_image(filename, image, spacing=None, affine=None, header=None, is_seg=False, mp_pool=None, free_mem=False):
+ if is_seg:
+ image = np.rint(image)
+ image = image.astype(np.int8) # In special cases segmentations can contain negative labels, so no np.uint8
+
+ image = sitk.GetImageFromArray(image)
+
+ if header is not None:
+        for key in header.keys():
+            image.SetMetaData(key, header[key])
+
+ if spacing is not None:
+ image.SetSpacing(spacing)
+
+ if affine is not None:
+ pass # How do I set the affine transform with SimpleITK? With NiBabel it is just nib.Nifti1Image(img, affine=affine, header=header)
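+        # (the reverse direction would decompose the affine: spacing = column norms of affine[:3, :3],
+        # direction = affine[:3, :3] with the spacing divided out, origin = affine[:3, 3], then apply them
+        # via image.SetSpacing(...), image.SetDirection(...) and image.SetOrigin(...); again mind LPS vs RAS)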
+
+ if mp_pool is None:
+ sitk.WriteImage(image, filename)
+ if free_mem:
+ del image
+ gc.collect()
+ else:
+ mp_pool.apply_async(_save, args=(filename, image, free_mem,))
+ if free_mem:
+ del image
+ gc.collect()
+
+
+def _save(filename, image, free_mem):
+ sitk.WriteImage(image, filename)
+ if free_mem:
+ del image
+ gc.collect()
+
+
+if __name__ == "__main__":
+ # Note: Due to a bug in SimpleITK 2.1.x a version of SimpleITK < 2.1.0 is required for loading images. Further, we can't copy the images and masks, but have to load them and resample both to the same spacing.
+ # Conversion instructions:
+ # 1. All sets, parts and CSVs need to be downloaded from https://ribfrac.grand-challenge.org/dataset/
+ # 2. Unzip ribfrac-train-images-1.zip (will be unzipped as Part1) and ribfrac-train-images-2.zip (will be unzipped as Part2), move content from Part2 to Part1 and rename the folder to imagesTr
+ # 3. Unzip ribfrac-train-labels-1.zip (will be unzipped as Part1) and ribfrac-train-labels-2.zip (will be unzipped as Part2), move content from Part2 to Part1 and rename the folder to labelsTr
+ # 4. Unzip ribfrac-val-images.zip and add content to imagesTr, repeat with ribfrac-val-labels.zip
+ # 5. Unzip ribfrac-test-images.zip and rename it to imagesTs
+
+ pool = mp.Pool(processes=20)
+
+ dataset_load_path = "/home/k539i/Documents/network_drives/E132-Projekte/Projects/2021_Gotkowski_RibFrac_RibSeg/original/RibFrac/"
+ dataset_save_path = "/home/k539i/Documents/network_drives/E132-Projekte/Projects/2021_Gotkowski_RibFrac_RibSeg/preprocessed/Task154_RibFrac_multi_label/"
+ preprocess_dataset(dataset_load_path, dataset_save_path, pool)
+
+ print("Still saving images in background...")
+ pool.close()
+ pool.join()
+ print("All tasks finished.")
+
+ labels = {0: "background", 1: "displaced_rib_fracture", 2: "non_displaced_rib_fracture", 3: "buckle_rib_fracture", 4: "segmental_rib_fracture", 5: "unidentified_rib_fracture"}
+ generate_dataset_json(join(dataset_save_path, 'dataset.json'), join(dataset_save_path, "imagesTr"), None, ('CT',), labels, "Task154_RibFrac_multi_label")
diff --git a/nnunet/dataset_conversion/Task155_RibFrac_binary.py b/nnunet/dataset_conversion/Task155_RibFrac_binary.py
new file mode 100755
index 0000000000000000000000000000000000000000..a69625b5eb8618f4537fa80435359f040278ad62
--- /dev/null
+++ b/nnunet/dataset_conversion/Task155_RibFrac_binary.py
@@ -0,0 +1,174 @@
+import SimpleITK as sitk
+from natsort import natsorted
+import numpy as np
+from pathlib import Path
+import pandas as pd
+from collections import defaultdict
+from shutil import copyfile
+import os
+from os.path import join
+from tqdm import tqdm
+import gc
+import multiprocessing as mp
+from nnunet.dataset_conversion.utils import generate_dataset_json
+from functools import partial
+
+
+def preprocess_dataset(dataset_load_path, dataset_save_path, pool):
+ train_image_load_path = join(dataset_load_path, "imagesTr")
+ train_mask_load_path = join(dataset_load_path, "labelsTr")
+ test_image_load_path = join(dataset_load_path, "imagesTs")
+
+ ribfrac_train_info_1_path = join(dataset_load_path, "ribfrac-train-info-1.csv")
+ ribfrac_train_info_2_path = join(dataset_load_path, "ribfrac-train-info-2.csv")
+ ribfrac_val_info_path = join(dataset_load_path, "ribfrac-val-info.csv")
+
+ train_image_save_path = join(dataset_save_path, "imagesTr")
+ train_mask_save_path = join(dataset_save_path, "labelsTr")
+ test_image_save_path = join(dataset_save_path, "imagesTs")
+ Path(train_image_save_path).mkdir(parents=True, exist_ok=True)
+ Path(train_mask_save_path).mkdir(parents=True, exist_ok=True)
+ Path(test_image_save_path).mkdir(parents=True, exist_ok=True)
+
+ meta_data = preprocess_csv(ribfrac_train_info_1_path, ribfrac_train_info_2_path, ribfrac_val_info_path)
+ preprocess_train(train_image_load_path, train_mask_load_path, meta_data, dataset_save_path, pool)
+ preprocess_test(test_image_load_path, dataset_save_path)
+
+
+def preprocess_csv(ribfrac_train_info_1_path, ribfrac_train_info_2_path, ribfrac_val_info_path):
+ print("Processing csv...")
+ meta_data = defaultdict(list)
+ for csv_path in [ribfrac_train_info_1_path, ribfrac_train_info_2_path, ribfrac_val_info_path]:
+ df = pd.read_csv(csv_path)
+ for index, row in df.iterrows():
+ name = row["public_id"]
+ instance = row["label_id"]
+ class_label = row["label_code"]
+ meta_data[name].append({"instance": instance, "class_label": class_label})
+ print("Finished csv processing.")
+ return meta_data
+
+
+def preprocess_train(image_path, mask_path, meta_data, save_path, pool):
+ print("Processing train data...")
+ pool.map(partial(preprocess_train_single, image_path=image_path, mask_path=mask_path, meta_data=meta_data, save_path=save_path), meta_data.keys())
+ print("Finished processing train data.")
+
+
+def preprocess_train_single(name, image_path, mask_path, meta_data, save_path):
+ id = int(name[7:])
+ image, _, _, _ = load_image(join(image_path, name + "-image.nii.gz"), return_meta=True, is_seg=False)
+ instance_seg_mask, spacing, _, _ = load_image(join(mask_path, name + "-label.nii.gz"), return_meta=True, is_seg=True)
+ semantic_seg_mask = np.zeros_like(instance_seg_mask, dtype=int)
+ for entry in meta_data[name]:
+ class_label = entry["class_label"]
+ if class_label > 0:
+ class_label = 1
+ semantic_seg_mask[instance_seg_mask == entry["instance"]] = class_label
+ save_image(join(save_path, "imagesTr/RibFrac_" + str(id).zfill(4) + "_0000.nii.gz"), image, spacing=spacing, is_seg=False)
+ save_image(join(save_path, "labelsTr/RibFrac_" + str(id).zfill(4) + ".nii.gz"), semantic_seg_mask, spacing=spacing, is_seg=True)
+
+
+def preprocess_test(load_test_image_dir, save_path):
+ print("Processing test data...")
+ filenames = load_filenames(load_test_image_dir)
+ for filename in tqdm(filenames):
+ id = int(os.path.basename(filename)[8:-13])
+ copyfile(filename, join(save_path, "imagesTs/RibFrac_" + str(id).zfill(4) + "_0000.nii.gz"))
+ print("Finished processing test data.")
+
+
+def load_filenames(img_dir, extensions=None):
+ _img_dir = fix_path(img_dir)
+ img_filenames = []
+
+ for file in os.listdir(_img_dir):
+ if extensions is None or file.endswith(extensions):
+ img_filenames.append(_img_dir + file)
+ img_filenames = np.asarray(img_filenames)
+ img_filenames = natsorted(img_filenames)
+
+ return img_filenames
+
+
+def fix_path(path):
+ if path[-1] != "/":
+ path += "/"
+ return path
+
+
+def load_image(filepath, return_meta=False, is_seg=False):
+ image = sitk.ReadImage(filepath)
+ image_np = sitk.GetArrayFromImage(image)
+
+ if is_seg:
+ image_np = np.rint(image_np)
+ image_np = image_np.astype(np.int8) # In special cases segmentations can contain negative labels, so no np.uint8
+
+ if not return_meta:
+ return image_np
+ else:
+ spacing = image.GetSpacing()
+ keys = image.GetMetaDataKeys()
+        header = {key: image.GetMetaData(key) for key in keys}
+ affine = None # How do I get the affine transform with SimpleITK? With NiBabel it is just image.affine
+ return image_np, spacing, affine, header
+
+
+def save_image(filename, image, spacing=None, affine=None, header=None, is_seg=False, mp_pool=None, free_mem=False):
+ if is_seg:
+ image = np.rint(image)
+ image = image.astype(np.int8) # In special cases segmentations can contain negative labels, so no np.uint8
+
+ image = sitk.GetImageFromArray(image)
+
+ if header is not None:
+        for key in header.keys():
+            image.SetMetaData(key, header[key])
+
+ if spacing is not None:
+ image.SetSpacing(spacing)
+
+ if affine is not None:
+ pass # How do I set the affine transform with SimpleITK? With NiBabel it is just nib.Nifti1Image(img, affine=affine, header=header)
+
+ if mp_pool is None:
+ sitk.WriteImage(image, filename)
+ if free_mem:
+ del image
+ gc.collect()
+ else:
+ mp_pool.apply_async(_save, args=(filename, image, free_mem,))
+ if free_mem:
+ del image
+ gc.collect()
+
+
+def _save(filename, image, free_mem):
+ sitk.WriteImage(image, filename)
+ if free_mem:
+ del image
+ gc.collect()
+
+
+if __name__ == "__main__":
+ # Note: Due to a bug in SimpleITK 2.1.x a version of SimpleITK < 2.1.0 is required for loading images. Further, we can't copy the images and masks, but have to load them and resample both to the same spacing.
+ # Conversion instructions:
+ # 1. All sets, parts and CSVs need to be downloaded from https://ribfrac.grand-challenge.org/dataset/
+ # 2. Unzip ribfrac-train-images-1.zip (will be unzipped as Part1) and ribfrac-train-images-2.zip (will be unzipped as Part2), move content from Part2 to Part1 and rename the folder to imagesTr
+ # 3. Unzip ribfrac-train-labels-1.zip (will be unzipped as Part1) and ribfrac-train-labels-2.zip (will be unzipped as Part2), move content from Part2 to Part1 and rename the folder to labelsTr
+ # 4. Unzip ribfrac-val-images.zip and add content to imagesTr, repeat with ribfrac-val-labels.zip
+ # 5. Unzip ribfrac-test-images.zip and rename it to imagesTs
+
+ pool = mp.Pool(processes=20)
+
+ dataset_load_path = "/home/k539i/Documents/network_drives/E132-Projekte/Projects/2021_Gotkowski_RibFrac_RibSeg/original/RibFrac/"
+ dataset_save_path = "/home/k539i/Documents/network_drives/E132-Projekte/Projects/2021_Gotkowski_RibFrac_RibSeg/preprocessed/Task155_RibFrac_binary/"
+ preprocess_dataset(dataset_load_path, dataset_save_path, pool)
+
+ print("Still saving images in background...")
+ pool.close()
+ pool.join()
+ print("All tasks finished.")
+
+ labels = {0: "background", 1: "fracture"}
+ generate_dataset_json(join(dataset_save_path, 'dataset.json'), join(dataset_save_path, "imagesTr"), None, ('CT',), labels, "Task155_RibFrac_binary")
diff --git a/nnunet/dataset_conversion/Task156_RibSeg.py b/nnunet/dataset_conversion/Task156_RibSeg.py
new file mode 100755
index 0000000000000000000000000000000000000000..678cf034f9faa07af5f4aeba75aca41826c6b192
--- /dev/null
+++ b/nnunet/dataset_conversion/Task156_RibSeg.py
@@ -0,0 +1,140 @@
+from natsort import natsorted
+import numpy as np
+from pathlib import Path
+import os
+from os.path import join
+from nnunet.dataset_conversion.utils import generate_dataset_json
+import SimpleITK as sitk
+import gc
+import multiprocessing as mp
+from functools import partial
+
+
+def preprocess_dataset(ribfrac_load_path, ribseg_load_path, dataset_save_path, pool):
+ mask_load_path = join(ribseg_load_path, "labelsTr")
+
+ train_image_save_path = join(dataset_save_path, "imagesTr")
+ train_mask_save_path = join(dataset_save_path, "labelsTr")
+ test_image_save_path = join(dataset_save_path, "imagesTs")
+ test_labels_save_path = join(dataset_save_path, "labelsTs")
+ Path(train_image_save_path).mkdir(parents=True, exist_ok=True)
+ Path(train_mask_save_path).mkdir(parents=True, exist_ok=True)
+ Path(test_image_save_path).mkdir(parents=True, exist_ok=True)
+ Path(test_labels_save_path).mkdir(parents=True, exist_ok=True)
+
+ mask_filenames = load_filenames(mask_load_path)
+    pool.map(partial(preprocess_single, image_load_path=ribfrac_load_path, dataset_save_path=dataset_save_path), mask_filenames)
+
+
+def preprocess_single(filename, image_load_path, dataset_save_path):
+    # dataset_save_path is passed explicitly so that this function does not rely on a global that only exists
+    # when the script is executed directly (which would break with spawn-based multiprocessing)
+ name = os.path.basename(filename)
+ if "-cl.nii.gz" in name:
+ return
+ id = int(name.split("-")[0][7:])
+ image_set = "imagesTr"
+ mask_set = "labelsTr"
+ if id > 500:
+ image_set = "imagesTs"
+ mask_set = "labelsTs"
+ image, _, _, _ = load_image(join(image_load_path, image_set, "RibFrac{}-image.nii.gz".format(id)), return_meta=True, is_seg=False)
+ mask, spacing, _, _ = load_image(filename, return_meta=True, is_seg=True)
+ save_image(join(dataset_save_path, image_set, "RibSeg_" + str(id).zfill(4) + "_0000.nii.gz"), image, spacing=spacing, is_seg=False)
+ save_image(join(dataset_save_path, mask_set, "RibSeg_" + str(id).zfill(4) + ".nii.gz"), mask, spacing=spacing, is_seg=True)
+
+
+def load_filenames(img_dir, extensions=None):
+ _img_dir = fix_path(img_dir)
+ img_filenames = []
+
+ for file in os.listdir(_img_dir):
+ if extensions is None or file.endswith(extensions):
+ img_filenames.append(_img_dir + file)
+ img_filenames = np.asarray(img_filenames)
+ img_filenames = natsorted(img_filenames)
+
+ return img_filenames
+
+
+def fix_path(path):
+ if path[-1] != "/":
+ path += "/"
+ return path
+
+
+def load_image(filepath, return_meta=False, is_seg=False):
+ image = sitk.ReadImage(filepath)
+ image_np = sitk.GetArrayFromImage(image)
+
+ if is_seg:
+ image_np = np.rint(image_np)
+ image_np = image_np.astype(np.int8) # In special cases segmentations can contain negative labels, so no np.uint8
+
+ if not return_meta:
+ return image_np
+ else:
+ spacing = image.GetSpacing()
+ keys = image.GetMetaDataKeys()
+        header = {key: image.GetMetaData(key) for key in keys}
+ affine = None # How do I get the affine transform with SimpleITK? With NiBabel it is just image.affine
+ return image_np, spacing, affine, header
+
+
+def save_image(filename, image, spacing=None, affine=None, header=None, is_seg=False, mp_pool=None, free_mem=False):
+ if is_seg:
+ image = np.rint(image)
+ image = image.astype(np.int8) # In special cases segmentations can contain negative labels, so no np.uint8
+
+ image = sitk.GetImageFromArray(image)
+
+ if header is not None:
+        for key in header.keys():
+            image.SetMetaData(key, header[key])
+
+ if spacing is not None:
+ image.SetSpacing(spacing)
+
+ if affine is not None:
+ pass # How do I set the affine transform with SimpleITK? With NiBabel it is just nib.Nifti1Image(img, affine=affine, header=header)
+
+ if mp_pool is None:
+ sitk.WriteImage(image, filename)
+ if free_mem:
+ del image
+ gc.collect()
+ else:
+ mp_pool.apply_async(_save, args=(filename, image, free_mem,))
+ if free_mem:
+ del image
+ gc.collect()
+
+
+def _save(filename, image, free_mem):
+ sitk.WriteImage(image, filename)
+ if free_mem:
+ del image
+ gc.collect()
+
+
+if __name__ == "__main__":
+ # Note: Due to a bug in SimpleITK 2.1.x a version of SimpleITK < 2.1.0 is required for loading images. Further, we can't copy the images and masks, but have to load them and resample both to the same spacing.
+ # Conversion instructions:
+ # 1. All images from both training and validation set of the RibFrac dataset need to be downloaded from https://ribfrac.grand-challenge.org/dataset/ into a new folder named RibFrac
+ # 2. The RibSeg masks need to be downloaded from https://zenodo.org/record/5336592 into a new folder named RibSeg
+ # 3. Follow unpacking instruction for the RibFrac dataset as in Task154_RibFrac
+ # 4. Unzip RibSeg_490_nii.zip from the RibSeg dataset and rename the folder labelsTr
+
+ ribfrac_load_path = "/home/k539i/Documents/datasets/original/RibFrac/"
+ ribseg_load_path = "/home/k539i/Documents/datasets/original/RibSeg/"
+ dataset_save_path = "/home/k539i/Documents/datasets/preprocessed/Task156_RibSeg/"
+
+    max_imagesTr_id = 500  # cases with an id above this are used as test set (see preprocess_single)
+
+ pool = mp.Pool(processes=20)
+
+ preprocess_dataset(ribfrac_load_path, ribseg_load_path, dataset_save_path, pool)
+
+ print("Still saving images in background...")
+ pool.close()
+ pool.join()
+ print("All tasks finished.")
+
+ generate_dataset_json(join(dataset_save_path, 'dataset.json'), join(dataset_save_path, "imagesTr"), None, ('CT',), {0: 'bg', 1: 'rib'}, "Task156_RibSeg")
diff --git a/nnunet/dataset_conversion/Task159_MyoPS2020.py b/nnunet/dataset_conversion/Task159_MyoPS2020.py
new file mode 100644
index 0000000000000000000000000000000000000000..7964721d839324a8294f58fc9be54d56cc938204
--- /dev/null
+++ b/nnunet/dataset_conversion/Task159_MyoPS2020.py
@@ -0,0 +1,106 @@
+import SimpleITK
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+import shutil
+
+import SimpleITK as sitk
+from nnunet.paths import nnUNet_raw_data
+from nnunet.dataset_conversion.utils import generate_dataset_json
+from nnunet.utilities.sitk_stuff import copy_geometry
+
+
+def convert_labels_to_nnunet(source_nifti: str, target_nifti: str):
+ img = sitk.ReadImage(source_nifti)
+ img_npy = sitk.GetArrayFromImage(img)
+ nnunet_seg = np.zeros(img_npy.shape, dtype=np.uint8)
+ # why are they not using normal labels and instead use random numbers???
+ nnunet_seg[img_npy == 500] = 1 # left ventricular (LV) blood pool (500)
+ nnunet_seg[img_npy == 600] = 2 # right ventricular blood pool (600)
+ nnunet_seg[img_npy == 200] = 3 # LV normal myocardium (200)
+ nnunet_seg[img_npy == 1220] = 4 # LV myocardial edema (1220)
+ nnunet_seg[img_npy == 2221] = 5 # LV myocardial scars (2221)
+ nnunet_seg_itk = sitk.GetImageFromArray(nnunet_seg)
+ nnunet_seg_itk = copy_geometry(nnunet_seg_itk, img)
+ sitk.WriteImage(nnunet_seg_itk, target_nifti)
+
+
+def convert_labels_back_to_myops(source_nifti: str, target_nifti: str):
+ nnunet_itk = sitk.ReadImage(source_nifti)
+ nnunet_npy = sitk.GetArrayFromImage(nnunet_itk)
+ myops_seg = np.zeros(nnunet_npy.shape, dtype=np.uint8)
+ # why are they not using normal labels and instead use random numbers???
+ myops_seg[nnunet_npy == 1] = 500 # left ventricular (LV) blood pool (500)
+ myops_seg[nnunet_npy == 2] = 600 # right ventricular blood pool (600)
+ myops_seg[nnunet_npy == 3] = 200 # LV normal myocardium (200)
+ myops_seg[nnunet_npy == 4] = 1220 # LV myocardial edema (1220)
+ myops_seg[nnunet_npy == 5] = 2221 # LV myocardial scars (2221)
+ myops_seg_itk = sitk.GetImageFromArray(myops_seg)
+ myops_seg_itk = copy_geometry(myops_seg_itk, nnunet_itk)
+ sitk.WriteImage(myops_seg_itk, target_nifti)
+
+
+if __name__ == '__main__':
+ # this is where we extracted all the archives. This folder must have the subfolders test20, train25,
+ # train25_myops_gd. We do not use test_data_gd because the test GT is encoded and cannot be used as it is
+ base = '/home/fabian/Downloads/MyoPS 2020 Dataset'
+
+    # Arbitrary task id. This is just to ensure each dataset has a unique number. Set this to whatever ([0-999]) you
+    # want
+ task_id = 159
+ task_name = "MyoPS2020"
+
+ foldername = "Task%03.0d_%s" % (task_id, task_name)
+
+ # setting up nnU-Net folders
+ out_base = join(nnUNet_raw_data, foldername)
+ imagestr = join(out_base, "imagesTr")
+ imagests = join(out_base, "imagesTs")
+ labelstr = join(out_base, "labelsTr")
+ maybe_mkdir_p(imagestr)
+ maybe_mkdir_p(imagests)
+ maybe_mkdir_p(labelstr)
+
+ imagestr_source = join(base, 'train25')
+ imagests_source = join(base, 'test20')
+ labelstr_source = join(base, 'train25_myops_gd')
+
+ # convert training set
+ nii_files = nifti_files(imagestr_source, join=False)
+ # remove their modality identifier. Conveniently it's always 2 characters. np.unique to get the identifiers
+ identifiers = np.unique([i[:-len('_C0.nii.gz')] for i in nii_files])
+ for i in identifiers:
+ shutil.copy(join(imagestr_source, i + "_C0.nii.gz"), join(imagestr, i + '_0000.nii.gz'))
+ shutil.copy(join(imagestr_source, i + "_DE.nii.gz"), join(imagestr, i + '_0001.nii.gz'))
+ shutil.copy(join(imagestr_source, i + "_T2.nii.gz"), join(imagestr, i + '_0002.nii.gz'))
+ convert_labels_to_nnunet(join(labelstr_source, i + '_gd.nii.gz'), join(labelstr, i + '.nii.gz'))
+
+ # test set
+ nii_files = nifti_files(imagests_source, join=False)
+ # remove their modality identifier. Conveniently it's always 2 characters. np.unique to get the identifiers
+ identifiers = np.unique([i[:-len('_C0.nii.gz')] for i in nii_files])
+ for i in identifiers:
+ shutil.copy(join(imagests_source, i + "_C0.nii.gz"), join(imagests, i + '_0000.nii.gz'))
+ shutil.copy(join(imagests_source, i + "_DE.nii.gz"), join(imagests, i + '_0001.nii.gz'))
+ shutil.copy(join(imagests_source, i + "_T2.nii.gz"), join(imagests, i + '_0002.nii.gz'))
+
+ generate_dataset_json(join(out_base, 'dataset.json'),
+ imagestr,
+ None,
+ ('C0', 'DE', 'T2'),
+ {
+ 0: 'background',
+ 1: "left ventricular (LV) blood pool",
+ 2: "right ventricular blood pool",
+ 3: "LV normal myocardium",
+ 4: "LV myocardial edema",
+ 5: "LV myocardial scars",
+ },
+ task_name,
+ license='see http://www.sdspeople.fudan.edu.cn/zhuangxiahai/0/myops20/index.html',
+ dataset_description='see http://www.sdspeople.fudan.edu.cn/zhuangxiahai/0/myops20/index.html',
+ dataset_reference='http://www.sdspeople.fudan.edu.cn/zhuangxiahai/0/myops20/index.html',
+ dataset_release='0')
+
+ # REMEMBER THAT TEST SET INFERENCE WILL REQUIRE YOU CONVERT THE LABELS BACK TO THEIR CONVENTION
+ # use convert_labels_back_to_myops for that!
+ # man I am such a nice person. Love you guys.
\ No newline at end of file
diff --git a/nnunet/dataset_conversion/__init__.py b/nnunet/dataset_conversion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..568ed9f17c4948f9699c185ba1dc8d2a1b494914
--- /dev/null
+++ b/nnunet/dataset_conversion/__init__.py
@@ -0,0 +1,3 @@
+from __future__ import absolute_import
+
+from . import *
\ No newline at end of file
diff --git a/nnunet/dataset_conversion/utils.py b/nnunet/dataset_conversion/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ebc75b2c664320066d32122994ad055fe081eaf
--- /dev/null
+++ b/nnunet/dataset_conversion/utils.py
@@ -0,0 +1,76 @@
+
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Tuple
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+
+
+def get_identifiers_from_splitted_files(folder: str):
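+    # the 12 characters stripped below are the modality index plus file extension, e.g. '_0000.nii.gz'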
+ uniques = np.unique([i[:-12] for i in subfiles(folder, suffix='.nii.gz', join=False)])
+ return uniques
+
+
+def generate_dataset_json(output_file: str, imagesTr_dir: str, imagesTs_dir: str, modalities: Tuple,
+ labels: dict, dataset_name: str, sort_keys=True, license: str = "hands off!", dataset_description: str = "",
+ dataset_reference="", dataset_release='0.0'):
+ """
+ :param output_file: This needs to be the full path to the dataset.json you intend to write, so
+ output_file='DATASET_PATH/dataset.json' where the folder DATASET_PATH points to is the one with the
+ imagesTr and labelsTr subfolders
+ :param imagesTr_dir: path to the imagesTr folder of that dataset
+ :param imagesTs_dir: path to the imagesTs folder of that dataset. Can be None
+ :param modalities: tuple of strings with modality names. must be in the same order as the images (first entry
+ corresponds to _0000.nii.gz, etc). Example: ('T1', 'T2', 'FLAIR').
+ :param labels: dict with int->str (key->value) mapping the label IDs to label names. Note that 0 is always
+ supposed to be background! Example: {0: 'background', 1: 'edema', 2: 'enhancing tumor'}
+ :param dataset_name: The name of the dataset. Can be anything you want
+    :param sort_keys: whether or not to sort the keys in dataset.json
+ :param license:
+ :param dataset_description:
+ :param dataset_reference: website of the dataset, if available
+ :param dataset_release:
+ :return:
+ """
+ train_identifiers = get_identifiers_from_splitted_files(imagesTr_dir)
+
+ if imagesTs_dir is not None:
+ test_identifiers = get_identifiers_from_splitted_files(imagesTs_dir)
+ else:
+ test_identifiers = []
+
+ json_dict = {}
+ json_dict['name'] = dataset_name
+ json_dict['description'] = dataset_description
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = dataset_reference
+ json_dict['licence'] = license
+ json_dict['release'] = dataset_release
+ json_dict['modality'] = {str(i): modalities[i] for i in range(len(modalities))}
+ json_dict['labels'] = {str(i): labels[i] for i in labels.keys()}
+
+ json_dict['numTraining'] = len(train_identifiers)
+ json_dict['numTest'] = len(test_identifiers)
+ json_dict['training'] = [
+ {'image': "./imagesTr/%s.nii.gz" % i, "label": "./labelsTr/%s.nii.gz" % i} for i
+ in
+ train_identifiers]
+ json_dict['test'] = ["./imagesTs/%s.nii.gz" % i for i in test_identifiers]
+
+ if not output_file.endswith("dataset.json"):
+ print("WARNING: output file name is not dataset.json! This may be intentional or not. You decide. "
+ "Proceeding anyways...")
+    save_json(json_dict, output_file, sort_keys=sort_keys)
diff --git a/nnunet/evaluation/__init__.py b/nnunet/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b8078b9dddddf22182fec2555d8d118ea72622
--- /dev/null
+++ b/nnunet/evaluation/__init__.py
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *
\ No newline at end of file
diff --git a/nnunet/evaluation/add_dummy_task_with_mean_over_all_tasks.py b/nnunet/evaluation/add_dummy_task_with_mean_over_all_tasks.py
new file mode 100644
index 0000000000000000000000000000000000000000..670bf20c71e777d34afac31a729e0da2e6d9c6cd
--- /dev/null
+++ b/nnunet/evaluation/add_dummy_task_with_mean_over_all_tasks.py
@@ -0,0 +1,77 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import subfiles
+import os
+from collections import OrderedDict
+
+folder = "/home/fabian/drives/E132-Projekte/Projects/2018_MedicalDecathlon/Leaderboard"
+task_descriptors = ['2D final 2',
+ '2D final, less pool, dc and topK, fold0',
+ '2D final pseudo3d 7, fold0',
+ '2D final, less pool, dc and ce, fold0',
+ '3D stage0 final 2, fold0',
+ '3D fullres final 2, fold0']
+task_ids_with_no_stage0 = ["Task001_BrainTumour", "Task004_Hippocampus", "Task005_Prostate"]
+
+mean_scores = OrderedDict()
+for t in task_descriptors:
+ mean_scores[t] = OrderedDict()
+
+json_files = subfiles(folder, True, None, ".json", True)
+json_files = [i for i in json_files if not i.split("/")[-1].startswith(".")] # stupid mac
+for j in json_files:
+ with open(j, 'r') as f:
+ res = json.load(f)
+ task = res['task']
+ if task != "Task999_ALL":
+ name = res['name']
+ if name in task_descriptors:
+ if task not in list(mean_scores[name].keys()):
+ mean_scores[name][task] = res['results']['mean']['mean']
+ else:
+ raise RuntimeError("duplicate task %s for description %s" % (task, name))
+
+for t in task_ids_with_no_stage0:
+ mean_scores["3D stage0 final 2, fold0"][t] = mean_scores["3D fullres final 2, fold0"][t]
+
+a = set()
+for i in mean_scores.keys():
+ a = a.union(list(mean_scores[i].keys()))
+
+for i in mean_scores.keys():
+ try:
+ for t in list(a):
+ assert t in mean_scores[i].keys(), "did not find task %s for experiment %s" % (t, i)
+ new_res = OrderedDict()
+ new_res['name'] = i
+ new_res['author'] = "Fabian"
+ new_res['task'] = "Task999_ALL"
+ new_res['results'] = OrderedDict()
+ new_res['results']['mean'] = OrderedDict()
+ new_res['results']['mean']['mean'] = OrderedDict()
+ tasks = list(mean_scores[i].keys())
+ metrics = mean_scores[i][tasks[0]].keys()
+ for m in metrics:
+ foreground_values = [mean_scores[i][n][m] for n in tasks]
+ new_res['results']['mean']["mean"][m] = np.nanmean(foreground_values)
+ output_fname = i.replace(" ", "_") + "_globalMean.json"
+ with open(os.path.join(folder, output_fname), 'w') as f:
+ json.dump(new_res, f)
+ except AssertionError:
+ print("could not process experiment %s" % i)
+ print("did not find task %s for experiment %s" % (t, i))
+
diff --git a/nnunet/evaluation/add_mean_dice_to_json.py b/nnunet/evaluation/add_mean_dice_to_json.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4f428b0567dcf1a99c6cfca90682f1c465208d8
--- /dev/null
+++ b/nnunet/evaluation/add_mean_dice_to_json.py
@@ -0,0 +1,51 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import subfiles
+from collections import OrderedDict
+
+
+def foreground_mean(filename):
+ with open(filename, 'r') as f:
+ res = json.load(f)
+ class_ids = np.array([int(i) for i in res['results']['mean'].keys() if (i != 'mean')])
+ class_ids = class_ids[class_ids != 0]
+ class_ids = class_ids[class_ids != -1]
+ class_ids = class_ids[class_ids != 99]
+
+    res['results']['mean'].pop('99', None)
+
+ metrics = res['results']['mean']['1'].keys()
+ res['results']['mean']["mean"] = OrderedDict()
+ for m in metrics:
+ foreground_values = [res['results']['mean'][str(i)][m] for i in class_ids]
+ res['results']['mean']["mean"][m] = np.nanmean(foreground_values)
+ with open(filename, 'w') as f:
+ json.dump(res, f, indent=4, sort_keys=True)
+
+
+def run_in_folder(folder):
+ json_files = subfiles(folder, True, None, ".json", True)
+ json_files = [i for i in json_files if not i.split("/")[-1].startswith(".") and not i.endswith("_globalMean.json")] # stupid mac
+ for j in json_files:
+ foreground_mean(j)
+
+
+if __name__ == "__main__":
+ folder = "/media/fabian/Results/nnUNetOutput_final/summary_jsons"
+ run_in_folder(folder)
diff --git a/nnunet/evaluation/collect_results_files.py b/nnunet/evaluation/collect_results_files.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1ea6bfec9f3bedbf75e1fe63b6c232360bc34dd
--- /dev/null
+++ b/nnunet/evaluation/collect_results_files.py
@@ -0,0 +1,48 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import shutil
+from batchgenerators.utilities.file_and_folder_operations import subdirs, subfiles
+
+
+def crawl_and_copy(current_folder, out_folder, prefix="fabian_", suffix="ummary.json"):
+    """
+    This function runs recursively through all subfolders of current_folder and copies all files that end with
+    suffix, prepending an automatically generated prefix, into out_folder
+    :param current_folder:
+    :param out_folder:
+    :param prefix:
+    :param suffix:
+    :return:
+    """
+ s = subdirs(current_folder, join=False)
+ f = subfiles(current_folder, join=False)
+ f = [i for i in f if i.endswith(suffix)]
+ if current_folder.find("fold0") != -1:
+ for fl in f:
+ shutil.copy(os.path.join(current_folder, fl), os.path.join(out_folder, prefix+fl))
+ for su in s:
+ if prefix == "":
+ add = su
+ else:
+ add = "__" + su
+ crawl_and_copy(os.path.join(current_folder, su), out_folder, prefix=prefix+add)
+
+
+if __name__ == "__main__":
+ from nnunet.paths import network_training_output_dir
+ output_folder = "/home/fabian/PhD/results/nnUNetV2/leaderboard"
+ crawl_and_copy(network_training_output_dir, output_folder)
+ from nnunet.evaluation.add_mean_dice_to_json import run_in_folder
+ run_in_folder(output_folder)
diff --git a/nnunet/evaluation/evaluator.py b/nnunet/evaluation/evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c74fd11d546b4081a4d57434dde459161fa9869
--- /dev/null
+++ b/nnunet/evaluation/evaluator.py
@@ -0,0 +1,483 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import collections
+import inspect
+import json
+import hashlib
+from datetime import datetime
+from multiprocessing.pool import Pool
+import numpy as np
+import pandas as pd
+import SimpleITK as sitk
+from nnunet.evaluation.metrics import ConfusionMatrix, ALL_METRICS
+from batchgenerators.utilities.file_and_folder_operations import save_json, subfiles, join
+from collections import OrderedDict
+
+
+class Evaluator:
+ """Object that holds test and reference segmentations with label information
+ and computes a number of metrics on the two. 'labels' must either be an
+ iterable of numeric values (or tuples thereof) or a dictionary with string
+ names and numeric values.
+ """
+
+ default_metrics = [
+ "False Positive Rate",
+ "Dice",
+ "Jaccard",
+ "Precision",
+ "Recall",
+ "Accuracy",
+ "False Omission Rate",
+ "Negative Predictive Value",
+ "False Negative Rate",
+ "True Negative Rate",
+ "False Discovery Rate",
+ "Total Positives Test",
+ "Total Positives Reference"
+ ]
+
+ default_advanced_metrics = [
+ #"Hausdorff Distance",
+ "Hausdorff Distance 95",
+ #"Avg. Surface Distance",
+ #"Avg. Symmetric Surface Distance"
+ ]
+
+ def __init__(self,
+ test=None,
+ reference=None,
+ labels=None,
+ metrics=None,
+ advanced_metrics=None,
+ nan_for_nonexisting=True):
+
+ self.test = None
+ self.reference = None
+ self.confusion_matrix = ConfusionMatrix()
+ self.labels = None
+ self.nan_for_nonexisting = nan_for_nonexisting
+ self.result = None
+
+ self.metrics = []
+ if metrics is None:
+ for m in self.default_metrics:
+ self.metrics.append(m)
+ else:
+ for m in metrics:
+ self.metrics.append(m)
+
+ self.advanced_metrics = []
+ if advanced_metrics is None:
+ for m in self.default_advanced_metrics:
+ self.advanced_metrics.append(m)
+ else:
+ for m in advanced_metrics:
+ self.advanced_metrics.append(m)
+
+ self.set_reference(reference)
+ self.set_test(test)
+ if labels is not None:
+ self.set_labels(labels)
+ else:
+ if test is not None and reference is not None:
+ self.construct_labels()
+
+ def set_test(self, test):
+ """Set the test segmentation."""
+
+ self.test = test
+
+ def set_reference(self, reference):
+ """Set the reference segmentation."""
+
+ self.reference = reference
+
+ def set_labels(self, labels):
+ """Set the labels.
+ :param labels= may be a dictionary (int->str), a set (of ints), a tuple (of ints) or a list (of ints). Labels
+ will only have names if you pass a dictionary"""
+
+ if isinstance(labels, dict):
+ self.labels = collections.OrderedDict(labels)
+ elif isinstance(labels, set):
+ self.labels = list(labels)
+ elif isinstance(labels, np.ndarray):
+ self.labels = [i for i in labels]
+ elif isinstance(labels, (list, tuple)):
+ self.labels = labels
+ else:
+ raise TypeError("Can only handle dict, list, tuple, set & numpy array, but input is of type {}".format(type(labels)))
+
+ def construct_labels(self):
+ """Construct label set from unique entries in segmentations."""
+
+ if self.test is None and self.reference is None:
+ raise ValueError("No test or reference segmentations.")
+ elif self.test is None:
+ labels = np.unique(self.reference)
+ else:
+ labels = np.union1d(np.unique(self.test),
+ np.unique(self.reference))
+ self.labels = list(map(lambda x: int(x), labels))
+
+ def set_metrics(self, metrics):
+ """Set evaluation metrics"""
+
+ if isinstance(metrics, set):
+ self.metrics = list(metrics)
+ elif isinstance(metrics, (list, tuple, np.ndarray)):
+ self.metrics = metrics
+ else:
+ raise TypeError("Can only handle list, tuple, set & numpy array, but input is of type {}".format(type(metrics)))
+
+ def add_metric(self, metric):
+
+ if metric not in self.metrics:
+ self.metrics.append(metric)
+
+ def evaluate(self, test=None, reference=None, advanced=False, **metric_kwargs):
+ """Compute metrics for segmentations."""
+ if test is not None:
+ self.set_test(test)
+
+ if reference is not None:
+ self.set_reference(reference)
+
+ if self.test is None or self.reference is None:
+ raise ValueError("Need both test and reference segmentations.")
+
+ if self.labels is None:
+ self.construct_labels()
+
+ self.metrics.sort()
+
+ # get functions for evaluation
+        # somewhat convoluted, but allows users to define additional metrics
+ # on the fly, e.g. inside an IPython console
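+        # (e.g. defining a function named "Dice", with the same signature as the implementations in
+        # nnunet.evaluation.metrics, in the calling scope overrides the built-in one)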
+ _funcs = {m: ALL_METRICS[m] for m in self.metrics + self.advanced_metrics}
+ frames = inspect.getouterframes(inspect.currentframe())
+ for metric in self.metrics:
+ for f in frames:
+ if metric in f[0].f_locals:
+ _funcs[metric] = f[0].f_locals[metric]
+ break
+ else:
+ if metric in _funcs:
+ continue
+ else:
+ raise NotImplementedError(
+ "Metric {} not implemented.".format(metric))
+
+ # get results
+ self.result = OrderedDict()
+
+ eval_metrics = self.metrics
+ if advanced:
+ eval_metrics += self.advanced_metrics
+
+ if isinstance(self.labels, dict):
+
+ for label, name in self.labels.items():
+ k = str(name)
+ self.result[k] = OrderedDict()
+ if not hasattr(label, "__iter__"):
+ self.confusion_matrix.set_test(self.test == label)
+ self.confusion_matrix.set_reference(self.reference == label)
+ else:
+ current_test = 0
+ current_reference = 0
+ for l in label:
+ current_test += (self.test == l)
+ current_reference += (self.reference == l)
+ self.confusion_matrix.set_test(current_test)
+ self.confusion_matrix.set_reference(current_reference)
+ for metric in eval_metrics:
+ self.result[k][metric] = _funcs[metric](confusion_matrix=self.confusion_matrix,
+ nan_for_nonexisting=self.nan_for_nonexisting,
+ **metric_kwargs)
+
+ else:
+
+ for i, l in enumerate(self.labels):
+ k = str(l)
+ self.result[k] = OrderedDict()
+ self.confusion_matrix.set_test(self.test == l)
+ self.confusion_matrix.set_reference(self.reference == l)
+ for metric in eval_metrics:
+ self.result[k][metric] = _funcs[metric](confusion_matrix=self.confusion_matrix,
+ nan_for_nonexisting=self.nan_for_nonexisting,
+ **metric_kwargs)
+
+ return self.result
+
+ def to_dict(self):
+
+ if self.result is None:
+ self.evaluate()
+ return self.result
+
+ def to_array(self):
+ """Return result as numpy array (labels x metrics)."""
+
+ if self.result is None:
+            self.evaluate()
+
+ result_metrics = sorted(self.result[list(self.result.keys())[0]].keys())
+
+ a = np.zeros((len(self.labels), len(result_metrics)), dtype=np.float32)
+
+ if isinstance(self.labels, dict):
+ for i, label in enumerate(self.labels.keys()):
+ for j, metric in enumerate(result_metrics):
+                    a[i][j] = self.result[str(self.labels[label])][metric]  # results are keyed by str(label name)
+ else:
+ for i, label in enumerate(self.labels):
+ for j, metric in enumerate(result_metrics):
+                    a[i][j] = self.result[str(label)][metric]  # results are keyed by str(label)
+
+ return a
+
+ def to_pandas(self):
+ """Return result as pandas DataFrame."""
+
+ a = self.to_array()
+
+ if isinstance(self.labels, dict):
+ labels = list(self.labels.values())
+ else:
+ labels = self.labels
+
+ result_metrics = sorted(self.result[list(self.result.keys())[0]].keys())
+
+ return pd.DataFrame(a, index=labels, columns=result_metrics)
+
+
+class NiftiEvaluator(Evaluator):
+
+ def __init__(self, *args, **kwargs):
+
+ self.test_nifti = None
+ self.reference_nifti = None
+ super(NiftiEvaluator, self).__init__(*args, **kwargs)
+
+ def set_test(self, test):
+ """Set the test segmentation."""
+
+ if test is not None:
+ self.test_nifti = sitk.ReadImage(test)
+ super(NiftiEvaluator, self).set_test(sitk.GetArrayFromImage(self.test_nifti))
+ else:
+ self.test_nifti = None
+ super(NiftiEvaluator, self).set_test(test)
+
+ def set_reference(self, reference):
+ """Set the reference segmentation."""
+
+ if reference is not None:
+ self.reference_nifti = sitk.ReadImage(reference)
+ super(NiftiEvaluator, self).set_reference(sitk.GetArrayFromImage(self.reference_nifti))
+ else:
+ self.reference_nifti = None
+ super(NiftiEvaluator, self).set_reference(reference)
+
+ def evaluate(self, test=None, reference=None, voxel_spacing=None, **metric_kwargs):
+
+ if voxel_spacing is None:
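+            # GetSpacing() returns (x, y, z) whereas GetArrayFromImage yields (z, y, x) arrays, hence the flip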
+ voxel_spacing = np.array(self.test_nifti.GetSpacing())[::-1]
+ metric_kwargs["voxel_spacing"] = voxel_spacing
+
+ return super(NiftiEvaluator, self).evaluate(test, reference, **metric_kwargs)
+
+
+def run_evaluation(args):
+ test, ref, evaluator, metric_kwargs = args
+ # evaluate
+ evaluator.set_test(test)
+ evaluator.set_reference(ref)
+ if evaluator.labels is None:
+ evaluator.construct_labels()
+ current_scores = evaluator.evaluate(**metric_kwargs)
+    if isinstance(test, str):
+        current_scores["test"] = test
+    if isinstance(ref, str):
+        current_scores["reference"] = ref
+ return current_scores
+
+
+def aggregate_scores(test_ref_pairs,
+ evaluator=NiftiEvaluator,
+ labels=None,
+ nanmean=True,
+ json_output_file=None,
+ json_name="",
+ json_description="",
+ json_author="Fabian",
+ json_task="",
+ num_threads=2,
+ **metric_kwargs):
+ """
+    test = predicted segmentation
+    :param test_ref_pairs: list of (test, reference) pairs; with the default NiftiEvaluator these are file paths
+ :param evaluator:
+ :param labels: must be a dict of int-> str or a list of int
+ :param nanmean:
+ :param json_output_file:
+ :param json_name:
+ :param json_description:
+ :param json_author:
+ :param json_task:
+ :param metric_kwargs:
+ :return:
+ """
+
+ if type(evaluator) == type:
+ evaluator = evaluator()
+
+ if labels is not None:
+ evaluator.set_labels(labels)
+
+ all_scores = OrderedDict()
+ all_scores["all"] = []
+ all_scores["mean"] = OrderedDict()
+
+ test = [i[0] for i in test_ref_pairs]
+ ref = [i[1] for i in test_ref_pairs]
+ p = Pool(num_threads)
+ all_res = p.map(run_evaluation, zip(test, ref, [evaluator]*len(ref), [metric_kwargs]*len(ref)))
+ p.close()
+ p.join()
+
+ for i in range(len(all_res)):
+ all_scores["all"].append(all_res[i])
+
+ # append score list for mean
+ for label, score_dict in all_res[i].items():
+ if label in ("test", "reference"):
+ continue
+ if label not in all_scores["mean"]:
+ all_scores["mean"][label] = OrderedDict()
+ for score, value in score_dict.items():
+ if score not in all_scores["mean"][label]:
+ all_scores["mean"][label][score] = []
+ all_scores["mean"][label][score].append(value)
+
+ for label in all_scores["mean"]:
+ for score in all_scores["mean"][label]:
+ if nanmean:
+ all_scores["mean"][label][score] = float(np.nanmean(all_scores["mean"][label][score]))
+ else:
+ all_scores["mean"][label][score] = float(np.mean(all_scores["mean"][label][score]))
+
+ # save to file if desired
+ # we create a hopefully unique id by hashing the entire output dictionary
+ if json_output_file is not None:
+ json_dict = OrderedDict()
+ json_dict["name"] = json_name
+ json_dict["description"] = json_description
+ timestamp = datetime.today()
+ json_dict["timestamp"] = str(timestamp)
+ json_dict["task"] = json_task
+ json_dict["author"] = json_author
+ json_dict["results"] = all_scores
+ json_dict["id"] = hashlib.md5(json.dumps(json_dict).encode("utf-8")).hexdigest()[:12]
+ save_json(json_dict, json_output_file)
+
+
+ return all_scores
+
+
+def aggregate_scores_for_experiment(score_file,
+ labels=None,
+ metrics=Evaluator.default_metrics,
+ nanmean=True,
+ json_output_file=None,
+ json_name="",
+ json_description="",
+ json_author="Fabian",
+ json_task=""):
+
+ scores = np.load(score_file)
+ scores_mean = scores.mean(0)
+ if labels is None:
+ labels = list(map(str, range(scores.shape[1])))
+
+ results = []
+ results_mean = OrderedDict()
+ for i in range(scores.shape[0]):
+ results.append(OrderedDict())
+ for l, label in enumerate(labels):
+ results[-1][label] = OrderedDict()
+ results_mean[label] = OrderedDict()
+ for m, metric in enumerate(metrics):
+ results[-1][label][metric] = float(scores[i][l][m])
+ results_mean[label][metric] = float(scores_mean[l][m])
+
+ json_dict = OrderedDict()
+ json_dict["name"] = json_name
+ json_dict["description"] = json_description
+ timestamp = datetime.today()
+ json_dict["timestamp"] = str(timestamp)
+ json_dict["task"] = json_task
+ json_dict["author"] = json_author
+ json_dict["results"] = {"all": results, "mean": results_mean}
+ json_dict["id"] = hashlib.md5(json.dumps(json_dict).encode("utf-8")).hexdigest()[:12]
+    if json_output_file is not None:
+        with open(json_output_file, "w") as f:
+            json.dump(json_dict, f, indent=4, separators=(",", ": "))
+
+ return json_dict
+
+
+def evaluate_folder(folder_with_gts: str, folder_with_predictions: str, labels: tuple, **metric_kwargs):
+ """
+ writes a summary.json to folder_with_predictions
+ :param folder_with_gts: folder where the ground truth segmentations are saved. Must be nifti files.
+ :param folder_with_predictions: folder where the predicted segmentations are saved. Must be nifti files.
+ :param labels: tuple of int with the labels in the dataset. For example (0, 1, 2, 3) for Task001_BrainTumour.
+ :return:
+ """
+ files_gt = subfiles(folder_with_gts, suffix=".nii.gz", join=False)
+ files_pred = subfiles(folder_with_predictions, suffix=".nii.gz", join=False)
+ assert all([i in files_pred for i in files_gt]), "files missing in folder_with_predictions"
+ assert all([i in files_gt for i in files_pred]), "files missing in folder_with_gts"
+ test_ref_pairs = [(join(folder_with_predictions, i), join(folder_with_gts, i)) for i in files_pred]
+ res = aggregate_scores(test_ref_pairs, json_output_file=join(folder_with_predictions, "summary.json"),
+ num_threads=8, labels=labels, **metric_kwargs)
+ return res
+
+
+def nnunet_evaluate_folder():
+ import argparse
+ parser = argparse.ArgumentParser("Evaluates the segmentations located in the folder pred. Output of this script is "
+ "a json file. At the very bottom of the json file is going to be a 'mean' "
+                                     "entry with averaged metrics across all cases")
+ parser.add_argument('-ref', required=True, type=str, help="Folder containing the reference segmentations in nifti "
+ "format.")
+ parser.add_argument('-pred', required=True, type=str, help="Folder containing the predicted segmentations in nifti "
+ "format. File names must match between the folders!")
+ parser.add_argument('-l', nargs='+', type=int, required=True, help="List of label IDs (integer values) that should "
+ "be evaluated. Best practice is to use all int "
+ "values present in the dataset. For example, "
+ "LiTS has the labels 0: background, 1: liver "
+ "and 2: tumor, so this argument should be "
+ "-l 1 2. You can also evaluate the background "
+ "label (0), but that will not give any useful "
+ "information.")
+ args = parser.parse_args()
+ return evaluate_folder(args.ref, args.pred, args.l)
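+
+# Example command line, assuming the nnUNet_evaluate_folder console entry point
+# for this function is installed:
+# nnUNet_evaluate_folder -ref FOLDER_WITH_GT -pred FOLDER_WITH_PRED -l 1 2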
diff --git a/nnunet/evaluation/metrics.py b/nnunet/evaluation/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..db08f946d30db527d22dc7169209776661b70f99
--- /dev/null
+++ b/nnunet/evaluation/metrics.py
@@ -0,0 +1,406 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from medpy import metric
+
+
+def assert_shape(test, reference):
+
+ assert test.shape == reference.shape, "Shape mismatch: {} and {}".format(
+ test.shape, reference.shape)
+
+
+class ConfusionMatrix:
+
+ def __init__(self, test=None, reference=None):
+
+ self.tp = None
+ self.fp = None
+ self.tn = None
+ self.fn = None
+ self.size = None
+ self.reference_empty = None
+ self.reference_full = None
+ self.test_empty = None
+ self.test_full = None
+ self.set_reference(reference)
+ self.set_test(test)
+
+ def set_test(self, test):
+
+ self.test = test
+ self.reset()
+
+ def set_reference(self, reference):
+
+ self.reference = reference
+ self.reset()
+
+ def reset(self):
+
+ self.tp = None
+ self.fp = None
+ self.tn = None
+ self.fn = None
+ self.size = None
+ self.test_empty = None
+ self.test_full = None
+ self.reference_empty = None
+ self.reference_full = None
+
+ def compute(self):
+
+ if self.test is None or self.reference is None:
+ raise ValueError("'test' and 'reference' must both be set to compute confusion matrix.")
+
+ assert_shape(self.test, self.reference)
+
+ self.tp = int(((self.test != 0) * (self.reference != 0)).sum())
+ self.fp = int(((self.test != 0) * (self.reference == 0)).sum())
+ self.tn = int(((self.test == 0) * (self.reference == 0)).sum())
+ self.fn = int(((self.test == 0) * (self.reference != 0)).sum())
+ self.size = int(np.prod(self.reference.shape, dtype=np.int64))
+ self.test_empty = not np.any(self.test)
+ self.test_full = np.all(self.test)
+ self.reference_empty = not np.any(self.reference)
+ self.reference_full = np.all(self.reference)
+
+ def get_matrix(self):
+
+ for entry in (self.tp, self.fp, self.tn, self.fn):
+ if entry is None:
+ self.compute()
+ break
+
+ return self.tp, self.fp, self.tn, self.fn
+
+ def get_size(self):
+
+ if self.size is None:
+ self.compute()
+ return self.size
+
+ def get_existence(self):
+
+ for case in (self.test_empty, self.test_full, self.reference_empty, self.reference_full):
+ if case is None:
+ self.compute()
+ break
+
+ return self.test_empty, self.test_full, self.reference_empty, self.reference_full
+
+
+def dice(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
+ """2TP / (2TP + FP + FN)"""
+
+ if confusion_matrix is None:
+ confusion_matrix = ConfusionMatrix(test, reference)
+
+ tp, fp, tn, fn = confusion_matrix.get_matrix()
+ test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
+
+ if test_empty and reference_empty:
+ if nan_for_nonexisting:
+ return float("NaN")
+ else:
+ return 0.
+
+ return float(2. * tp / (2 * tp + fp + fn))
+
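+# Example of the NaN convention shared by all metrics below:
+# dice(np.zeros((2, 2)), np.zeros((2, 2))) -> NaN (nothing to evaluate)
+# dice(np.zeros((2, 2)), np.zeros((2, 2)), nan_for_nonexisting=False) -> 0.0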
+
+def jaccard(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
+ """TP / (TP + FP + FN)"""
+
+ if confusion_matrix is None:
+ confusion_matrix = ConfusionMatrix(test, reference)
+
+ tp, fp, tn, fn = confusion_matrix.get_matrix()
+ test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
+
+ if test_empty and reference_empty:
+ if nan_for_nonexisting:
+ return float("NaN")
+ else:
+ return 0.
+
+ return float(tp / (tp + fp + fn))
+
+
+def precision(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
+ """TP / (TP + FP)"""
+
+ if confusion_matrix is None:
+ confusion_matrix = ConfusionMatrix(test, reference)
+
+ tp, fp, tn, fn = confusion_matrix.get_matrix()
+ test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
+
+ if test_empty:
+ if nan_for_nonexisting:
+ return float("NaN")
+ else:
+ return 0.
+
+ return float(tp / (tp + fp))
+
+
+def sensitivity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
+ """TP / (TP + FN)"""
+
+ if confusion_matrix is None:
+ confusion_matrix = ConfusionMatrix(test, reference)
+
+ tp, fp, tn, fn = confusion_matrix.get_matrix()
+ test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
+
+ if reference_empty:
+ if nan_for_nonexisting:
+ return float("NaN")
+ else:
+ return 0.
+
+ return float(tp / (tp + fn))
+
+
+def recall(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
+ """TP / (TP + FN)"""
+
+ return sensitivity(test, reference, confusion_matrix, nan_for_nonexisting, **kwargs)
+
+
+def specificity(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
+ """TN / (TN + FP)"""
+
+ if confusion_matrix is None:
+ confusion_matrix = ConfusionMatrix(test, reference)
+
+ tp, fp, tn, fn = confusion_matrix.get_matrix()
+ test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
+
+ if reference_full:
+ if nan_for_nonexisting:
+ return float("NaN")
+ else:
+ return 0.
+
+ return float(tn / (tn + fp))
+
+
+def accuracy(test=None, reference=None, confusion_matrix=None, **kwargs):
+ """(TP + TN) / (TP + FP + FN + TN)"""
+
+ if confusion_matrix is None:
+ confusion_matrix = ConfusionMatrix(test, reference)
+
+ tp, fp, tn, fn = confusion_matrix.get_matrix()
+
+ return float((tp + tn) / (tp + fp + tn + fn))
+
+
+def fscore(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, beta=1., **kwargs):
+ """(1 + b^2) * TP / ((1 + b^2) * TP + b^2 * FN + FP)"""
+
+ precision_ = precision(test, reference, confusion_matrix, nan_for_nonexisting)
+ recall_ = recall(test, reference, confusion_matrix, nan_for_nonexisting)
+
+ denominator = (beta*beta * precision_) + recall_
+ if denominator == 0:
+ # tp == 0 with nonzero fp and fn: the fscore is 0 by convention
+ return 0.
+
+ return (1 + beta*beta) * precision_ * recall_ / denominator
+
+
+def false_positive_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
+ """FP / (FP + TN)"""
+
+ return 1 - specificity(test, reference, confusion_matrix, nan_for_nonexisting)
+
+
+def false_omission_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
+ """FN / (TN + FN)"""
+
+ if confusion_matrix is None:
+ confusion_matrix = ConfusionMatrix(test, reference)
+
+ tp, fp, tn, fn = confusion_matrix.get_matrix()
+ test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
+
+ if test_full:
+ if nan_for_nonexisting:
+ return float("NaN")
+ else:
+ return 0.
+
+ return float(fn / (fn + tn))
+
+
+def false_negative_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
+ """FN / (TP + FN)"""
+
+ return 1 - sensitivity(test, reference, confusion_matrix, nan_for_nonexisting)
+
+
+def true_negative_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
+ """TN / (TN + FP)"""
+
+ return specificity(test, reference, confusion_matrix, nan_for_nonexisting)
+
+
+def false_discovery_rate(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
+ """FP / (TP + FP)"""
+
+ return 1 - precision(test, reference, confusion_matrix, nan_for_nonexisting)
+
+
+def negative_predictive_value(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, **kwargs):
+ """TN / (TN + FN)"""
+
+ return 1 - false_omission_rate(test, reference, confusion_matrix, nan_for_nonexisting)
+
+
+def total_positives_test(test=None, reference=None, confusion_matrix=None, **kwargs):
+ """TP + FP"""
+
+ if confusion_matrix is None:
+ confusion_matrix = ConfusionMatrix(test, reference)
+
+ tp, fp, tn, fn = confusion_matrix.get_matrix()
+
+ return tp + fp
+
+
+def total_negatives_test(test=None, reference=None, confusion_matrix=None, **kwargs):
+ """TN + FN"""
+
+ if confusion_matrix is None:
+ confusion_matrix = ConfusionMatrix(test, reference)
+
+ tp, fp, tn, fn = confusion_matrix.get_matrix()
+
+ return tn + fn
+
+
+def total_positives_reference(test=None, reference=None, confusion_matrix=None, **kwargs):
+ """TP + FN"""
+
+ if confusion_matrix is None:
+ confusion_matrix = ConfusionMatrix(test, reference)
+
+ tp, fp, tn, fn = confusion_matrix.get_matrix()
+
+ return tp + fn
+
+
+def total_negatives_reference(test=None, reference=None, confusion_matrix=None, **kwargs):
+ """TN + FP"""
+
+ if confusion_matrix is None:
+ confusion_matrix = ConfusionMatrix(test, reference)
+
+ tp, fp, tn, fn = confusion_matrix.get_matrix()
+
+ return tn + fp
+
+
+def hausdorff_distance(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):
+
+ if confusion_matrix is None:
+ confusion_matrix = ConfusionMatrix(test, reference)
+
+ test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
+
+ if test_empty or test_full or reference_empty or reference_full:
+ if nan_for_nonexisting:
+ return float("NaN")
+ else:
+ return 0
+
+ test, reference = confusion_matrix.test, confusion_matrix.reference
+
+ return metric.hd(test, reference, voxel_spacing, connectivity)
+
+
+def hausdorff_distance_95(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):
+
+ if confusion_matrix is None:
+ confusion_matrix = ConfusionMatrix(test, reference)
+
+ test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
+
+ if test_empty or test_full or reference_empty or reference_full:
+ if nan_for_nonexisting:
+ return float("NaN")
+ else:
+ return 0
+
+ test, reference = confusion_matrix.test, confusion_matrix.reference
+
+ return metric.hd95(test, reference, voxel_spacing, connectivity)
+
+
+def avg_surface_distance(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):
+
+ if confusion_matrix is None:
+ confusion_matrix = ConfusionMatrix(test, reference)
+
+ test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
+
+ if test_empty or test_full or reference_empty or reference_full:
+ if nan_for_nonexisting:
+ return float("NaN")
+ else:
+ return 0
+
+ test, reference = confusion_matrix.test, confusion_matrix.reference
+
+ return metric.asd(test, reference, voxel_spacing, connectivity)
+
+
+def avg_surface_distance_symmetric(test=None, reference=None, confusion_matrix=None, nan_for_nonexisting=True, voxel_spacing=None, connectivity=1, **kwargs):
+
+ if confusion_matrix is None:
+ confusion_matrix = ConfusionMatrix(test, reference)
+
+ test_empty, test_full, reference_empty, reference_full = confusion_matrix.get_existence()
+
+ if test_empty or test_full or reference_empty or reference_full:
+ if nan_for_nonexisting:
+ return float("NaN")
+ else:
+ return 0
+
+ test, reference = confusion_matrix.test, confusion_matrix.reference
+
+ return metric.assd(test, reference, voxel_spacing, connectivity)
+
+
+ALL_METRICS = {
+ "False Positive Rate": false_positive_rate,
+ "Dice": dice,
+ "Jaccard": jaccard,
+ "Hausdorff Distance": hausdorff_distance,
+ "Hausdorff Distance 95": hausdorff_distance_95,
+ "Precision": precision,
+ "Recall": recall,
+ "Avg. Symmetric Surface Distance": avg_surface_distance_symmetric,
+ "Avg. Surface Distance": avg_surface_distance,
+ "Accuracy": accuracy,
+ "False Omission Rate": false_omission_rate,
+ "Negative Predictive Value": negative_predictive_value,
+ "False Negative Rate": false_negative_rate,
+ "True Negative Rate": true_negative_rate,
+ "False Discovery Rate": false_discovery_rate,
+ "Total Positives Test": total_positives_test,
+ "Total Negatives Test": total_negatives_test,
+ "Total Positives Reference": total_positives_reference,
+ "total Negatives Reference": total_negatives_reference
+}
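+
+# Minimal usage sketch (illustrative): every metric shares the same signature,
+# so a single ConfusionMatrix can be reused for all of them.
+#
+# test = np.random.randint(0, 2, (8, 8, 8))
+# reference = np.random.randint(0, 2, (8, 8, 8))
+# cm = ConfusionMatrix(test, reference)
+# scores = {name: fn(confusion_matrix=cm) for name, fn in ALL_METRICS.items()}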
diff --git a/nnunet/evaluation/model_selection/__init__.py b/nnunet/evaluation/model_selection/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b8078b9dddddf22182fec2555d8d118ea72622
--- /dev/null
+++ b/nnunet/evaluation/model_selection/__init__.py
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *
\ No newline at end of file
diff --git a/nnunet/evaluation/model_selection/collect_all_fold0_results_and_summarize_in_one_csv.py b/nnunet/evaluation/model_selection/collect_all_fold0_results_and_summarize_in_one_csv.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdfd492b33ce8e58e02bd7933a590b39556eed5f
--- /dev/null
+++ b/nnunet/evaluation/model_selection/collect_all_fold0_results_and_summarize_in_one_csv.py
@@ -0,0 +1,73 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nnunet.evaluation.model_selection.summarize_results_in_one_json import summarize2
+from nnunet.paths import network_training_output_dir
+from batchgenerators.utilities.file_and_folder_operations import *
+
+if __name__ == "__main__":
+ summary_output_folder = join(network_training_output_dir, "summary_jsons_fold0_new")
+ maybe_mkdir_p(summary_output_folder)
+ summarize2(['all'], output_dir=summary_output_folder, folds=(0,))
+
+ results_csv = join(network_training_output_dir, "summary_fold0.csv")
+
+ summary_files = subfiles(summary_output_folder, suffix='.json', join=False)
+
+ with open(results_csv, 'w') as f:
+ for s in summary_files:
+ if s.find("ensemble") == -1:
+ task, network, trainer, plans, validation_folder, folds = s.split("__")
+ else:
+ n1, n2 = s.split("--")
+ n1 = n1[n1.find("ensemble_") + len("ensemble_") :]
+ task = s.split("__")[0]
+ network = "ensemble"
+ trainer = n1
+ plans = n2
+ validation_folder = "none"
+ folds = folds[:-len('.json')]
+ results = load_json(join(summary_output_folder, s))
+ results_mean = results['results']['mean']['mean']['Dice']
+ results_median = results['results']['median']['mean']['Dice']
+ f.write("%s,%s,%s,%s,%s,%02.4f,%02.4f\n" % (task,
+ network, trainer, validation_folder, plans, results_mean, results_median))
+
+ summary_output_folder = join(network_training_output_dir, "summary_jsons_new")
+ maybe_mkdir_p(summary_output_folder)
+ summarize2(['all'], output_dir=summary_output_folder)
+
+ results_csv = join(network_training_output_dir, "summary_allFolds.csv")
+
+ summary_files = subfiles(summary_output_folder, suffix='.json', join=False)
+
+ with open(results_csv, 'w') as f:
+ for s in summary_files:
+ if s.find("ensemble") == -1:
+ task, network, trainer, plans, validation_folder, folds = s.split("__")
+ else:
+ n1, n2 = s.split("--")
+ n1 = n1[n1.find("ensemble_") + len("ensemble_") :]
+ task = s.split("__")[0]
+ network = "ensemble"
+ trainer = n1
+ plans = n2
+ validation_folder = "none"
+ folds = folds[:-len('.json')]
+ results = load_json(join(summary_output_folder, s))
+ results_mean = results['results']['mean']['mean']['Dice']
+ results_median = results['results']['median']['mean']['Dice']
+ f.write("%s,%s,%s,%s,%s,%02.4f,%02.4f\n" % (task,
+ network, trainer, validation_folder, plans, results_mean, results_median))
+
diff --git a/nnunet/evaluation/model_selection/ensemble.py b/nnunet/evaluation/model_selection/ensemble.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e0a489d1d95822ef580bbb3d7e2c8f38b2735e4
--- /dev/null
+++ b/nnunet/evaluation/model_selection/ensemble.py
@@ -0,0 +1,123 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import shutil
+from multiprocessing.pool import Pool
+
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.configuration import default_num_threads
+from nnunet.evaluation.evaluator import aggregate_scores
+from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax
+from nnunet.paths import network_training_output_dir, preprocessing_output_dir
+from nnunet.postprocessing.connected_components import determine_postprocessing
+
+
+def merge(args):
+ file1, file2, properties_file, out_file = args
+ if not isfile(out_file):
+ res1 = np.load(file1)['softmax']
+ res2 = np.load(file2)['softmax']
+ props = load_pickle(properties_file)
+ mn = np.mean((res1, res2), 0)
+ # Softmax probabilities are already at target spacing so this will not do any resampling (resampling parameters
+ # don't matter here)
+ save_segmentation_nifti_from_softmax(mn, out_file, props, 3, None, None, None, force_separate_z=None,
+ interpolation_order_z=0)
+
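+# merge takes a single 4-tuple instead of four arguments so that it can be fed
+# directly to Pool.map with the zipped file lists in ensemble() below.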
+
+def ensemble(training_output_folder1, training_output_folder2, output_folder, task, validation_folder, folds, allow_ensembling: bool = True):
+ print("\nEnsembling folders\n", training_output_folder1, "\n", training_output_folder2)
+
+ output_folder_base = output_folder
+ output_folder = join(output_folder_base, "ensembled_raw")
+
+ # only_keep_largest_connected_component is the same for all stages
+ dataset_directory = join(preprocessing_output_dir, task)
+ plans = load_pickle(join(training_output_folder1, "plans.pkl")) # we need this only for the labels
+
+ files1 = []
+ files2 = []
+ property_files = []
+ out_files = []
+ gt_segmentations = []
+
+ folder_with_gt_segs = join(dataset_directory, "gt_segmentations")
+ # the gt segmentations are already in the correct shape; we need their original geometry to restore the niftis
+
+ for f in folds:
+ validation_folder_net1 = join(training_output_folder1, "fold_%d" % f, validation_folder)
+ validation_folder_net2 = join(training_output_folder2, "fold_%d" % f, validation_folder)
+
+ if not isdir(validation_folder_net1):
+ raise AssertionError("Validation directory missing: %s. Please rerun validation with `nnUNet_train CONFIG TRAINER TASK FOLD -val --npz`" % validation_folder_net1)
+ if not isdir(validation_folder_net2):
+ raise AssertionError("Validation directory missing: %s. Please rerun validation with `nnUNet_train CONFIG TRAINER TASK FOLD -val --npz`" % validation_folder_net2)
+
+ # we need to ensure the validation was successful. We can verify this via the presence of the summary.json file
+ if not isfile(join(validation_folder_net1, 'summary.json')):
+ raise AssertionError("Validation directory incomplete: %s. Please rerun validation with `nnUNet_train CONFIG TRAINER TASK FOLD -val --npz`" % validation_folder_net1)
+ if not isfile(join(validation_folder_net2, 'summary.json')):
+ raise AssertionError("Validation directory incomplete: %s. Please rerun validation with `nnUNet_train CONFIG TRAINER TASK FOLD -val --npz`" % validation_folder_net2)
+
+ patient_identifiers1_npz = [i[:-4] for i in subfiles(validation_folder_net1, False, None, 'npz', True)]
+ patient_identifiers2_npz = [i[:-4] for i in subfiles(validation_folder_net2, False, None, 'npz', True)]
+
+ # postprocessing is no longer done during validation, so there should not be any noPostProcess or
+ # _postprocessed files; exclude them just in case
+ patient_identifiers1_nii = [i[:-7] for i in subfiles(validation_folder_net1, False, None, suffix='nii.gz', sort=True) if not i.endswith("noPostProcess.nii.gz") and not i.endswith('_postprocessed.nii.gz')]
+ patient_identifiers2_nii = [i[:-7] for i in subfiles(validation_folder_net2, False, None, suffix='nii.gz', sort=True) if not i.endswith("noPostProcess.nii.gz") and not i.endswith('_postprocessed.nii.gz')]
+
+ if not all([i in patient_identifiers1_npz for i in patient_identifiers1_nii]):
+ raise AssertionError("Missing npz files in folder %s. Please run the validation for all models and folds with the '--npz' flag." % (validation_folder_net1))
+ if not all([i in patient_identifiers2_npz for i in patient_identifiers2_nii]):
+ raise AssertionError("Missing npz files in folder %s. Please run the validation for all models and folds with the '--npz' flag." % (validation_folder_net2))
+
+ patient_identifiers1_npz.sort()
+ patient_identifiers2_npz.sort()
+
+ assert all([i == j for i, j in zip(patient_identifiers1_npz, patient_identifiers2_npz)]), "npz filenames do not match. This should not happen."
+
+ maybe_mkdir_p(output_folder)
+
+ for p in patient_identifiers1_npz:
+ files1.append(join(validation_folder_net1, p + '.npz'))
+ files2.append(join(validation_folder_net2, p + '.npz'))
+ property_files.append(join(validation_folder_net1, p) + ".pkl")
+ out_files.append(join(output_folder, p + ".nii.gz"))
+ gt_segmentations.append(join(folder_with_gt_segs, p + ".nii.gz"))
+
+ p = Pool(default_num_threads)
+ p.map(merge, zip(files1, files2, property_files, out_files))
+ p.close()
+ p.join()
+
+ if not isfile(join(output_folder, "summary.json")) and len(out_files) > 0:
+ aggregate_scores(tuple(zip(out_files, gt_segmentations)), labels=plans['all_classes'],
+ json_output_file=join(output_folder, "summary.json"), json_task=task,
+ json_name=task + "__" + output_folder_base.split("/")[-1], num_threads=default_num_threads)
+
+ if allow_ensembling and not isfile(join(output_folder_base, "postprocessing.json")):
+ # now let's also look at postprocessing. We cannot simply take what was determined during cross-validation and
+ # apply it here, because things may have changed and may also be too inconsistent between the two networks
+ determine_postprocessing(output_folder_base, folder_with_gt_segs, "ensembled_raw", "temp",
+ "ensembled_postprocessed", default_num_threads, dice_threshold=0)
+
+ out_dir_all_json = join(network_training_output_dir, "summary_jsons")
+ json_out = load_json(join(output_folder_base, "ensembled_postprocessed", "summary.json"))
+
+ json_out["experiment_name"] = output_folder_base.split("/")[-1]
+ save_json(json_out, join(output_folder_base, "ensembled_postprocessed", "summary.json"))
+
+ maybe_mkdir_p(out_dir_all_json)
+ shutil.copy(join(output_folder_base, "ensembled_postprocessed", "summary.json"),
+ join(out_dir_all_json, "%s__%s.json" % (task, output_folder_base.split("/")[-1])))
diff --git a/nnunet/evaluation/model_selection/figure_out_what_to_submit.py b/nnunet/evaluation/model_selection/figure_out_what_to_submit.py
new file mode 100644
index 0000000000000000000000000000000000000000..13522fd94d9be0f34057c4224dc3a2f39e0d7c1d
--- /dev/null
+++ b/nnunet/evaluation/model_selection/figure_out_what_to_submit.py
@@ -0,0 +1,235 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import shutil
+from itertools import combinations
+import nnunet
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.evaluation.add_mean_dice_to_json import foreground_mean
+from nnunet.evaluation.evaluator import evaluate_folder
+from nnunet.evaluation.model_selection.ensemble import ensemble
+from nnunet.paths import network_training_output_dir
+import numpy as np
+from subprocess import call
+from nnunet.postprocessing.consolidate_postprocessing import consolidate_folds, collect_cv_niftis
+from nnunet.utilities.folder_names import get_output_folder_name
+from nnunet.paths import default_cascade_trainer, default_trainer, default_plans_identifier
+
+
+def find_task_name(folder, task_id):
+ candidates = subdirs(folder, prefix="Task%03.0d_" % task_id, join=False)
+ assert len(candidates) > 0, "no candidate for Task id %d found in folder %s" % (task_id, folder)
+ assert len(candidates) == 1, "more than one candidate for Task id %d found in folder %s" % (task_id, folder)
+ return candidates[0]
+
+
+def get_mean_foreground_dice(json_file):
+ results = load_json(json_file)
+ return get_foreground_mean(results)
+
+
+def get_foreground_mean(results):
+ results_mean = results['results']['mean']
+ dice_scores = [results_mean[i]['Dice'] for i in results_mean.keys() if i != "0" and i != 'mean']
+ return np.mean(dice_scores)
+
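+# Worked example for get_foreground_mean: given
+# results['results']['mean'] = {'0': ..., '1': {'Dice': 0.8}, '2': {'Dice': 0.6}}
+# the background entry '0' (and a 'mean' entry, if present) is skipped and the
+# returned value is 0.7.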
+
+def main():
+ import argparse
+ parser = argparse.ArgumentParser(description="This is intended to identify the best model based on the five fold "
+ "cross-validation. Running this script requires all models to have been run "
+ "already. This script will summarize the results of the five folds of all "
+ "models in one json each for easy interpretability")
+
+ parser.add_argument("-m", '--models', nargs="+", required=False, default=['2d', '3d_lowres', '3d_fullres',
+ '3d_cascade_fullres'])
+ parser.add_argument("-t", '--task_ids', nargs="+", required=True)
+
+ parser.add_argument("-tr", type=str, required=False, default=default_trainer,
+ help="nnUNetTrainer class. Default: %s" % default_trainer)
+ parser.add_argument("-ctr", type=str, required=False, default=default_cascade_trainer,
+ help="nnUNetTrainer class for cascade model. Default: %s" % default_cascade_trainer)
+ parser.add_argument("-pl", type=str, required=False, default=default_plans_identifier,
+ help="plans name, Default: %s" % default_plans_identifier)
+ parser.add_argument('-f', '--folds', nargs='+', default=(0, 1, 2, 3, 4), help="Use this if you have non-standard "
+ "folds. Experienced users only.")
+ parser.add_argument('--disable_ensembling', required=False, default=False, action='store_true',
+ help='Set this flag to disable the use of ensembling. This will find the best single '
+ 'configuration for each task.')
+ parser.add_argument("--disable_postprocessing", required=False, default=False, action="store_true",
+ help="Set this flag if you want to disable the use of postprocessing")
+
+ args = parser.parse_args()
+ tasks = [int(i) for i in args.task_ids]
+
+ models = args.models
+ tr = args.tr
+ trc = args.ctr
+ pl = args.pl
+ disable_ensembling = args.disable_ensembling
+ disable_postprocessing = args.disable_postprocessing
+ folds = tuple(int(i) for i in args.folds)
+
+ validation_folder = "validation_raw"
+
+ # this script now acts independently of the summary jsons; depending on them was unnecessary
+ id_task_mapping = {}
+
+ for t in tasks:
+ # first collect pure model performance
+ results = {}
+ all_results = {}
+ valid_models = []
+ for m in models:
+ if m == "3d_cascade_fullres":
+ trainer = trc
+ else:
+ trainer = tr
+
+ if t not in id_task_mapping.keys():
+ task_name = find_task_name(get_output_folder_name(m), t)
+ id_task_mapping[t] = task_name
+
+ output_folder = get_output_folder_name(m, id_task_mapping[t], trainer, pl)
+ if not isdir(output_folder):
+ raise RuntimeError("Output folder for model %s is missing, expected: %s" % (m, output_folder))
+
+ if disable_postprocessing:
+ # we need to collect the predicted niftis from the 5-fold cv and evaluate them against the ground truth
+ cv_niftis_folder = join(output_folder, 'cv_niftis_raw')
+
+ if not isfile(join(cv_niftis_folder, 'summary.json')):
+ print(t, m, ': collecting niftis from 5-fold cv')
+ if isdir(cv_niftis_folder):
+ shutil.rmtree(cv_niftis_folder)
+
+ collect_cv_niftis(output_folder, cv_niftis_folder, validation_folder, folds)
+
+ niftis_gt = subfiles(join(output_folder, "gt_niftis"), suffix='.nii.gz', join=False)
+ niftis_cv = subfiles(cv_niftis_folder, suffix='.nii.gz', join=False)
+ if not all([i in niftis_gt for i in niftis_cv]):
+ raise AssertionError("It does not seem like you trained all the folds! Train "
+ "all folds first! There are %d gt niftis in %s but only "
+ "%d predicted niftis in %s" % (len(niftis_gt), join(output_folder, "gt_niftis"),
+ len(niftis_cv), cv_niftis_folder))
+
+ # load a summary file so that we can know what class labels to expect
+ summary_fold0 = load_json(join(output_folder, "fold_%d" % folds[0], validation_folder,
+ "summary.json"))['results']['mean']
+ # read classes from summary.json
+ classes = tuple((int(i) for i in summary_fold0.keys()))
+
+ # evaluate the cv niftis
+ print(t, m, ': evaluating 5-fold cv results')
+ evaluate_folder(join(output_folder, "gt_niftis"), cv_niftis_folder, classes)
+
+ else:
+ postprocessing_json = join(output_folder, "postprocessing.json")
+ cv_niftis_folder = join(output_folder, "cv_niftis_raw")
+
+ # we need the consolidated cv niftis to know the single model performance. And we need the
+ # postprocessing_json. If either of those is missing, rerun consolidate_folds
+ if not isfile(postprocessing_json) or not isdir(cv_niftis_folder):
+ print("running missing postprocessing for %s and model %s" % (id_task_mapping[t], m))
+ consolidate_folds(output_folder, folds=folds)
+
+ assert isfile(postprocessing_json), "Postprocessing json missing, expected: %s" % postprocessing_json
+ assert isdir(cv_niftis_folder), "Folder with niftis from CV missing, expected: %s" % cv_niftis_folder
+
+ # obtain mean foreground dice
+ summary_file = join(cv_niftis_folder, "summary.json")
+ results[m] = get_mean_foreground_dice(summary_file)
+ foreground_mean(summary_file)
+ all_results[m] = load_json(summary_file)['results']['mean']
+ valid_models.append(m)
+
+ if not disable_ensembling:
+ # now run ensembling and add ensembling to results
+ print("\nI will now ensemble combinations of the following models:\n", valid_models)
+ if len(valid_models) > 1:
+ for m1, m2 in combinations(valid_models, 2):
+
+ trainer_m1 = trc if m1 == "3d_cascade_fullres" else tr
+ trainer_m2 = trc if m2 == "3d_cascade_fullres" else tr
+
+ ensemble_name = "ensemble_" + m1 + "__" + trainer_m1 + "__" + pl + "--" + m2 + "__" + trainer_m2 + "__" + pl
+ output_folder_base = join(network_training_output_dir, "ensembles", id_task_mapping[t], ensemble_name)
+ maybe_mkdir_p(output_folder_base)
+
+ network1_folder = get_output_folder_name(m1, id_task_mapping[t], trainer_m1, pl)
+ network2_folder = get_output_folder_name(m2, id_task_mapping[t], trainer_m2, pl)
+
+ print("ensembling", network1_folder, network2_folder)
+ ensemble(network1_folder, network2_folder, output_folder_base, id_task_mapping[t], validation_folder, folds, allow_ensembling=not disable_postprocessing)
+ # ensembling will automatically do postprocessing
+
+ # now get result of ensemble
+ results[ensemble_name] = get_mean_foreground_dice(join(output_folder_base, "ensembled_raw", "summary.json"))
+ summary_file = join(output_folder_base, "ensembled_raw", "summary.json")
+ foreground_mean(summary_file)
+ all_results[ensemble_name] = load_json(summary_file)['results']['mean']
+
+ # now print all mean foreground dice and highlight the best
+ foreground_dices = list(results.values())
+ best = np.max(foreground_dices)
+ for k, v in results.items():
+ print(k, v)
+
+ predict_str = ""
+ best_model = None
+ for k, v in results.items():
+ if v == best:
+ print("%s submit model %s" % (id_task_mapping[t], k), v)
+ best_model = k
+ print("\nHere is how you should predict test cases. Run in sequential order and replace all input and output folder names with your personalized ones\n")
+ if k.startswith("ensemble"):
+ tmp = k[len("ensemble_"):]
+ model1, model2 = tmp.split("--")
+ m1, t1, pl1 = model1.split("__")
+ m2, t2, pl2 = model2.split("__")
+ predict_str += "nnUNet_predict -i FOLDER_WITH_TEST_CASES -o OUTPUT_FOLDER_MODEL1 -tr " + tr + " -ctr " + trc + " -m " + m1 + " -p " + pl + " -t " + \
+ id_task_mapping[t] + " -z\n"
+ predict_str += "nnUNet_predict -i FOLDER_WITH_TEST_CASES -o OUTPUT_FOLDER_MODEL2 -tr " + tr + " -ctr " + trc + " -m " + m2 + " -p " + pl + " -t " + \
+ id_task_mapping[t] + " -z\n"
+
+ if not disable_postprocessing:
+ predict_str += "nnUNet_ensemble -f OUTPUT_FOLDER_MODEL1 OUTPUT_FOLDER_MODEL2 -o OUTPUT_FOLDER -pp " + join(network_training_output_dir, "ensembles", id_task_mapping[t], k, "postprocessing.json") + "\n"
+ else:
+ predict_str += "nnUNet_ensemble -f OUTPUT_FOLDER_MODEL1 OUTPUT_FOLDER_MODEL2 -o OUTPUT_FOLDER\n"
+ else:
+ predict_str += "nnUNet_predict -i FOLDER_WITH_TEST_CASES -o OUTPUT_FOLDER_MODEL1 -tr " + tr + " -ctr " + trc + " -m " + k + " -p " + pl + " -t " + \
+ id_task_mapping[t] + "\n"
+ print(predict_str)
+
+ summary_folder = join(network_training_output_dir, "ensembles", id_task_mapping[t])
+ maybe_mkdir_p(summary_folder)
+ with open(join(summary_folder, "prediction_commands.txt"), 'w') as f:
+ f.write(predict_str)
+
+ num_classes = len([i for i in all_results[best_model].keys() if i != 'mean' and i != '0'])
+ with open(join(summary_folder, "summary.csv"), 'w') as f:
+ f.write("model")
+ for c in range(1, num_classes + 1):
+ f.write(",class%d" % c)
+ f.write(",average")
+ f.write("\n")
+ for m in all_results.keys():
+ f.write(m)
+ for c in range(1, num_classes + 1):
+ f.write(",%01.4f" % all_results[m][str(c)]["Dice"])
+ f.write(",%01.4f" % all_results[m]['mean']["Dice"])
+ f.write("\n")
+
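+# The resulting summary.csv has one row per model/ensemble (values illustrative):
+# model,class1,class2,average
+# 3d_fullres,0.9123,0.8456,0.8790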
+
+if __name__ == "__main__":
+ main()
diff --git a/nnunet/evaluation/model_selection/rank_candidates.py b/nnunet/evaluation/model_selection/rank_candidates.py
new file mode 100644
index 0000000000000000000000000000000000000000..c293da9d777c8ba77b48c6c8195fcd696577c0ee
--- /dev/null
+++ b/nnunet/evaluation/model_selection/rank_candidates.py
@@ -0,0 +1,294 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.paths import network_training_output_dir
+
+if __name__ == "__main__":
+ # run collect_all_fold0_results_and_summarize_in_one_csv.py first
+ summary_files_dir = join(network_training_output_dir, "summary_jsons_fold0_new")
+ output_file = join(network_training_output_dir, "summary.csv")
+
+ folds = (0, )
+ folds_str = ""
+ for f in folds:
+ folds_str += str(f)
+
+ plans = "nnUNetPlans"
+
+ overwrite_plans = {
+ 'nnUNetTrainerV2_2': ["nnUNetPlans", "nnUNetPlansisoPatchesInVoxels"], # r
+ 'nnUNetTrainerV2': ["nnUNetPlansnonCT", "nnUNetPlansCT2", "nnUNetPlansallConv3x3",
+ "nnUNetPlansfixedisoPatchesInVoxels", "nnUNetPlanstargetSpacingForAnisoAxis",
+ "nnUNetPlanspoolBasedOnSpacing", "nnUNetPlansfixedisoPatchesInmm", "nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_warmup': ["nnUNetPlans", "nnUNetPlansv2.1", "nnUNetPlansv2.1_big", "nnUNetPlansv2.1_verybig"],
+ 'nnUNetTrainerV2_cycleAtEnd': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_cycleAtEnd2': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_reduceMomentumDuringTraining': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_graduallyTransitionFromCEToDice': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_independentScalePerAxis': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_Mish': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_Ranger_lr3en4': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_fp32': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_GN': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_momentum098': ["nnUNetPlans", "nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_momentum09': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_DP': ["nnUNetPlansv2.1_verybig"],
+ 'nnUNetTrainerV2_DDP': ["nnUNetPlansv2.1_verybig"],
+ 'nnUNetTrainerV2_FRN': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_resample33': ["nnUNetPlansv2.3"],
+ 'nnUNetTrainerV2_O2': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_ResencUNet': ["nnUNetPlans_FabiansResUNet_v2.1"],
+ 'nnUNetTrainerV2_DA2': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_allConv3x3': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_ForceBD': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_ForceSD': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_LReLU_slope_2en1': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_lReLU_convReLUIN': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_ReLU': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_ReLU_biasInSegOutput': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_ReLU_convReLUIN': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_lReLU_biasInSegOutput': ["nnUNetPlansv2.1"],
+ #'nnUNetTrainerV2_Loss_MCC': ["nnUNetPlansv2.1"],
+ #'nnUNetTrainerV2_Loss_MCCnoBG': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_Loss_DicewithBG': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_Loss_Dice_LR1en3': ["nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_Loss_Dice': ["nnUNetPlans", "nnUNetPlansv2.1"],
+ 'nnUNetTrainerV2_Loss_DicewithBG_LR1en3': ["nnUNetPlansv2.1"],
+
+ }
+
+ trainers = ['nnUNetTrainer'] + ['nnUNetTrainerNewCandidate%d' % i for i in range(1, 28)] + [
+ 'nnUNetTrainerNewCandidate24_2',
+ 'nnUNetTrainerNewCandidate24_3',
+ 'nnUNetTrainerNewCandidate26_2',
+ 'nnUNetTrainerNewCandidate27_2',
+ 'nnUNetTrainerNewCandidate23_always3DDA',
+ 'nnUNetTrainerNewCandidate23_corrInit',
+ 'nnUNetTrainerNewCandidate23_noOversampling',
+ 'nnUNetTrainerNewCandidate23_softDS',
+ 'nnUNetTrainerNewCandidate23_softDS2',
+ 'nnUNetTrainerNewCandidate23_softDS3',
+ 'nnUNetTrainerNewCandidate23_softDS4',
+ 'nnUNetTrainerNewCandidate23_2_fp16',
+ 'nnUNetTrainerNewCandidate23_2',
+ 'nnUNetTrainerVer2',
+ 'nnUNetTrainerV2_2',
+ 'nnUNetTrainerV2_3',
+ 'nnUNetTrainerV2_3_CE_GDL',
+ 'nnUNetTrainerV2_3_dcTopk10',
+ 'nnUNetTrainerV2_3_dcTopk20',
+ 'nnUNetTrainerV2_3_fp16',
+ 'nnUNetTrainerV2_3_softDS4',
+ 'nnUNetTrainerV2_3_softDS4_clean',
+ 'nnUNetTrainerV2_3_softDS4_clean_improvedDA',
+ 'nnUNetTrainerV2_3_softDS4_clean_improvedDA_newElDef',
+ 'nnUNetTrainerV2_3_softDS4_radam',
+ 'nnUNetTrainerV2_3_softDS4_radam_lowerLR',
+
+ 'nnUNetTrainerV2_2_schedule',
+ 'nnUNetTrainerV2_2_schedule2',
+ 'nnUNetTrainerV2_2_clean',
+ 'nnUNetTrainerV2_2_clean_improvedDA_newElDef',
+
+ 'nnUNetTrainerV2_2_fixes', # running
+ 'nnUNetTrainerV2_BN', # running
+ 'nnUNetTrainerV2_noDeepSupervision', # running
+ 'nnUNetTrainerV2_softDeepSupervision', # running
+ 'nnUNetTrainerV2_noDataAugmentation', # running
+ 'nnUNetTrainerV2_Loss_CE', # running
+ 'nnUNetTrainerV2_Loss_CEGDL',
+ 'nnUNetTrainerV2_Loss_Dice',
+ 'nnUNetTrainerV2_Loss_DiceTopK10',
+ 'nnUNetTrainerV2_Loss_TopK10',
+ 'nnUNetTrainerV2_Adam', # running
+ 'nnUNetTrainerV2_Adam_nnUNetTrainerlr', # running
+ 'nnUNetTrainerV2_SGD_ReduceOnPlateau', # running
+ 'nnUNetTrainerV2_SGD_lr1en1', # running
+ 'nnUNetTrainerV2_SGD_lr1en3', # running
+ 'nnUNetTrainerV2_fixedNonlin', # running
+ 'nnUNetTrainerV2_GeLU', # running
+ 'nnUNetTrainerV2_3ConvPerStage',
+ 'nnUNetTrainerV2_NoNormalization',
+ 'nnUNetTrainerV2_Adam_ReduceOnPlateau',
+ 'nnUNetTrainerV2_fp16',
+ 'nnUNetTrainerV2', # see overwrite_plans
+ 'nnUNetTrainerV2_noMirroring',
+ 'nnUNetTrainerV2_momentum09',
+ 'nnUNetTrainerV2_momentum095',
+ 'nnUNetTrainerV2_momentum098',
+ 'nnUNetTrainerV2_warmup',
+ 'nnUNetTrainerV2_Loss_Dice_LR1en3',
+ 'nnUNetTrainerV2_NoNormalization_lr1en3',
+ 'nnUNetTrainerV2_Loss_Dice_squared',
+ 'nnUNetTrainerV2_newElDef',
+ 'nnUNetTrainerV2_fp32',
+ 'nnUNetTrainerV2_cycleAtEnd',
+ 'nnUNetTrainerV2_reduceMomentumDuringTraining',
+ 'nnUNetTrainerV2_graduallyTransitionFromCEToDice',
+ 'nnUNetTrainerV2_insaneDA',
+ 'nnUNetTrainerV2_independentScalePerAxis',
+ 'nnUNetTrainerV2_Mish',
+ 'nnUNetTrainerV2_Ranger_lr3en4',
+ 'nnUNetTrainerV2_cycleAtEnd2',
+ 'nnUNetTrainerV2_GN',
+ 'nnUNetTrainerV2_DP',
+ 'nnUNetTrainerV2_FRN',
+ 'nnUNetTrainerV2_resample33',
+ 'nnUNetTrainerV2_O2',
+ 'nnUNetTrainerV2_ResencUNet',
+ 'nnUNetTrainerV2_DA2',
+ 'nnUNetTrainerV2_allConv3x3',
+ 'nnUNetTrainerV2_ForceBD',
+ 'nnUNetTrainerV2_ForceSD',
+ 'nnUNetTrainerV2_ReLU',
+ 'nnUNetTrainerV2_LReLU_slope_2en1',
+ 'nnUNetTrainerV2_lReLU_convReLUIN',
+ 'nnUNetTrainerV2_ReLU_biasInSegOutput',
+ 'nnUNetTrainerV2_ReLU_convReLUIN',
+ 'nnUNetTrainerV2_lReLU_biasInSegOutput',
+ 'nnUNetTrainerV2_Loss_DicewithBG_LR1en3',
+ #'nnUNetTrainerV2_Loss_MCCnoBG',
+ 'nnUNetTrainerV2_Loss_DicewithBG',
+ ]
+
+ datasets = \
+ {"Task001_BrainTumour": ("3d_fullres", ),
+ "Task002_Heart": ("3d_fullres",),
+ #"Task024_Promise": ("3d_fullres",),
+ #"Task027_ACDC": ("3d_fullres",),
+ "Task003_Liver": ("3d_fullres", "3d_lowres"),
+ "Task004_Hippocampus": ("3d_fullres",),
+ "Task005_Prostate": ("3d_fullres",),
+ "Task006_Lung": ("3d_fullres", "3d_lowres"),
+ "Task007_Pancreas": ("3d_fullres", "3d_lowres"),
+ "Task008_HepaticVessel": ("3d_fullres", "3d_lowres"),
+ "Task009_Spleen": ("3d_fullres", "3d_lowres"),
+ "Task010_Colon": ("3d_fullres", "3d_lowres"),}
+
+ expected_validation_folder = "validation_raw"
+ alternative_validation_folder = "validation"
+ alternative_alternative_validation_folder = "validation_tiledTrue_doMirror_True"
+
+ interested_in = "mean"
+
+ result_per_dataset = {}
+ for d in datasets:
+ result_per_dataset[d] = {}
+ for c in datasets[d]:
+ result_per_dataset[d][c] = []
+
+ valid_trainers = []
+ all_trainers = []
+
+ with open(output_file, 'w') as f:
+ f.write("trainer,")
+ for t in datasets.keys():
+ s = t[4:7]
+ for c in datasets[t]:
+ s1 = s + "_" + c[3]
+ f.write("%s," % s1)
+ f.write("\n")
+
+ for trainer in trainers:
+ trainer_plans = [plans]
+ if trainer in overwrite_plans.keys():
+ trainer_plans = overwrite_plans[trainer]
+
+ result_per_dataset_here = {}
+ for d in datasets:
+ result_per_dataset_here[d] = {}
+
+ for p in trainer_plans:
+ name = "%s__%s" % (trainer, p)
+ all_present = True
+ all_trainers.append(name)
+
+ f.write("%s," % name)
+ for dataset in datasets.keys():
+ for configuration in datasets[dataset]:
+ summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p, expected_validation_folder, folds_str))
+ if not isfile(summary_file):
+ summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p, alternative_validation_folder, folds_str))
+ if not isfile(summary_file):
+ summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (
+ dataset, configuration, trainer, p, alternative_alternative_validation_folder, folds_str))
+ if not isfile(summary_file):
+ all_present = False
+ print(name, dataset, configuration, "has missing summary file")
+ if isfile(summary_file):
+ result = load_json(summary_file)['results'][interested_in]['mean']['Dice']
+ result_per_dataset_here[dataset][configuration] = result
+ f.write("%02.4f," % result)
+ else:
+ f.write("NA,")
+ result_per_dataset_here[dataset][configuration] = 0
+
+ f.write("\n")
+
+ if True: # note: all_present is ignored here; trainers with missing summaries are kept (their results were recorded as 0 above)
+ valid_trainers.append(name)
+ for d in datasets:
+ for c in datasets[d]:
+ result_per_dataset[d][c].append(result_per_dataset_here[d][c])
+
+ invalid_trainers = [i for i in all_trainers if i not in valid_trainers]
+
+ num_valid = len(valid_trainers)
+ num_datasets = len(datasets.keys())
+ # create an array that is trainer x dataset. If more than one configuration is there then use the best metric across the two
+ all_res = np.zeros((num_valid, num_datasets))
+ for j, d in enumerate(datasets.keys()):
+ ks = list(result_per_dataset[d].keys())
+ tmp = result_per_dataset[d][ks[0]]
+ for k in ks[1:]:
+ for i in range(len(tmp)):
+ tmp[i] = max(tmp[i], result_per_dataset[d][k][i])
+ all_res[:, j] = tmp
+
+ ranks_arr = np.zeros_like(all_res)
+ for d in range(ranks_arr.shape[1]):
+ temp = np.argsort(all_res[:, d])[::-1] # inverse because we want the highest dice to be rank0
+ ranks = np.empty_like(temp)
+ ranks[temp] = np.arange(len(temp))
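+ # worked example: all_res[:, d] = [0.70, 0.90, 0.85] -> temp = [1, 2, 0]
+ # -> ranks = [2, 0, 1], i.e. the highest Dice receives rank 0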
+
+ ranks_arr[:, d] = ranks
+
+ mn = np.mean(ranks_arr, 1)
+ for i in np.argsort(mn):
+ print(mn[i], valid_trainers[i])
+
+ print()
+ print(valid_trainers[np.argmin(mn)])
diff --git a/nnunet/evaluation/model_selection/rank_candidates_StructSeg.py b/nnunet/evaluation/model_selection/rank_candidates_StructSeg.py
new file mode 100644
index 0000000000000000000000000000000000000000..9858f1ff2496e39d41d17f6cae886de6a621e2a8
--- /dev/null
+++ b/nnunet/evaluation/model_selection/rank_candidates_StructSeg.py
@@ -0,0 +1,159 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.paths import network_training_output_dir
+
+if __name__ == "__main__":
+ # run collect_all_fold0_results_and_summarize_in_one_csv.py first
+ summary_files_dir = join(network_training_output_dir, "summary_jsons_new")
+ output_file = join(network_training_output_dir, "summary_structseg_5folds.csv")
+
+ folds = (0, 1, 2, 3, 4)
+ folds_str = ""
+ for f in folds:
+ folds_str += str(f)
+
+ plans = "nnUNetPlans"
+
+ overwrite_plans = {
+ 'nnUNetTrainerV2_2': ["nnUNetPlans", "nnUNetPlans_customClip"], # r
+ 'nnUNetTrainerV2_2_noMirror': ["nnUNetPlans", "nnUNetPlans_customClip"], # r
+ 'nnUNetTrainerV2_lessMomentum_noMirror': ["nnUNetPlans", "nnUNetPlans_customClip"], # r
+ 'nnUNetTrainerV2_2_structSeg_noMirror': ["nnUNetPlans", "nnUNetPlans_customClip"], # r
+ 'nnUNetTrainerV2_2_structSeg': ["nnUNetPlans", "nnUNetPlans_customClip"], # r
+ 'nnUNetTrainerV2_lessMomentum_noMirror_structSeg': ["nnUNetPlans", "nnUNetPlans_customClip"], # r
+ 'nnUNetTrainerV2_FabiansResUNet_structSet_NoMirror_leakyDecoder': ["nnUNetPlans", "nnUNetPlans_customClip"], # r
+ 'nnUNetTrainerV2_FabiansResUNet_structSet_NoMirror': ["nnUNetPlans", "nnUNetPlans_customClip"], # r
+ 'nnUNetTrainerV2_FabiansResUNet_structSet': ["nnUNetPlans", "nnUNetPlans_customClip"], # r
+
+ }
+
+ trainers = ['nnUNetTrainer'] + [
+ 'nnUNetTrainerV2_2',
+ 'nnUNetTrainerV2_lessMomentum_noMirror',
+ 'nnUNetTrainerV2_2_noMirror',
+ 'nnUNetTrainerV2_2_structSeg_noMirror',
+ 'nnUNetTrainerV2_2_structSeg',
+ 'nnUNetTrainerV2_lessMomentum_noMirror_structSeg',
+ 'nnUNetTrainerV2_FabiansResUNet_structSet_NoMirror_leakyDecoder',
+ 'nnUNetTrainerV2_FabiansResUNet_structSet_NoMirror',
+ 'nnUNetTrainerV2_FabiansResUNet_structSet',
+ ]
+
+ datasets = \
+ {"Task049_StructSeg2019_Task1_HaN_OAR": ("3d_fullres", "3d_lowres", "2d"),
+ "Task050_StructSeg2019_Task2_Naso_GTV": ("3d_fullres", "3d_lowres", "2d"),
+ "Task051_StructSeg2019_Task3_Thoracic_OAR": ("3d_fullres", "3d_lowres", "2d"),
+ "Task052_StructSeg2019_Task4_Lung_GTV": ("3d_fullres", "3d_lowres", "2d"),
+}
+
+ expected_validation_folder = "validation_raw"
+ alternative_validation_folder = "validation"
+ alternative_alternative_validation_folder = "validation_tiledTrue_doMirror_True"
+
+ interested_in = "mean"
+
+ result_per_dataset = {}
+ for d in datasets:
+ result_per_dataset[d] = {}
+ for c in datasets[d]:
+ result_per_dataset[d][c] = []
+
+ valid_trainers = []
+ all_trainers = []
+
+ with open(output_file, 'w') as f:
+ f.write("trainer,")
+ for t in datasets.keys():
+ s = t[4:7]
+ for c in datasets[t]:
+ if len(c) > 3:
+ n = c[3]
+ else:
+ n = "2"
+ s1 = s + "_" + n
+ f.write("%s," % s1)
+ f.write("\n")
+
+ for trainer in trainers:
+ trainer_plans = [plans]
+ if trainer in overwrite_plans.keys():
+ trainer_plans = overwrite_plans[trainer]
+
+ result_per_dataset_here = {}
+ for d in datasets:
+ result_per_dataset_here[d] = {}
+
+ for p in trainer_plans:
+ name = "%s__%s" % (trainer, p)
+ all_present = True
+ all_trainers.append(name)
+
+ f.write("%s," % name)
+ for dataset in datasets.keys():
+ for configuration in datasets[dataset]:
+ summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p, expected_validation_folder, folds_str))
+ if not isfile(summary_file):
+ summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p, alternative_validation_folder, folds_str))
+ if not isfile(summary_file):
+ summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (
+ dataset, configuration, trainer, p, alternative_alternative_validation_folder, folds_str))
+ if not isfile(summary_file):
+ all_present = False
+ print(name, dataset, configuration, "has missing summary file")
+ if isfile(summary_file):
+ result = load_json(summary_file)['results'][interested_in]['mean']['Dice']
+ result_per_dataset_here[dataset][configuration] = result
+ f.write("%02.4f," % result)
+ else:
+ f.write("NA,")
+ f.write("\n")
+
+ if all_present:
+ valid_trainers.append(name)
+ for d in datasets:
+ for c in datasets[d]:
+ result_per_dataset[d][c].append(result_per_dataset_here[d][c])
+
+ invalid_trainers = [i for i in all_trainers if i not in valid_trainers]
+
+ num_valid = len(valid_trainers)
+ num_datasets = len(datasets.keys())
+ # create an array that is trainer x dataset. If more than one configuration is there then use the best metric across the two
+ all_res = np.zeros((num_valid, num_datasets))
+ for j, d in enumerate(datasets.keys()):
+ ks = list(result_per_dataset[d].keys())
+ tmp = result_per_dataset[d][ks[0]]
+ for k in ks[1:]:
+ for i in range(len(tmp)):
+ tmp[i] = max(tmp[i], result_per_dataset[d][k][i])
+ all_res[:, j] = tmp
+
+ ranks_arr = np.zeros_like(all_res)
+ for d in range(ranks_arr.shape[1]):
+ temp = np.argsort(all_res[:, d])[::-1] # inverse because we want the highest dice to be rank0
+ ranks = np.empty_like(temp)
+ ranks[temp] = np.arange(len(temp))
+
+ ranks_arr[:, d] = ranks
+
+ mn = np.mean(ranks_arr, 1)
+ for i in np.argsort(mn):
+ print(mn[i], valid_trainers[i])
+
+ print()
+ print(valid_trainers[np.argmin(mn)])
diff --git a/nnunet/evaluation/model_selection/rank_candidates_cascade.py b/nnunet/evaluation/model_selection/rank_candidates_cascade.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a38268306df3bce593e540ba35a8f48e6bfea07
--- /dev/null
+++ b/nnunet/evaluation/model_selection/rank_candidates_cascade.py
@@ -0,0 +1,164 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.paths import network_training_output_dir
+
+if __name__ == "__main__":
+ # run collect_all_fold0_results_and_summarize_in_one_csv.py first
+ summary_files_dir = join(network_training_output_dir, "summary_jsons_fold0_new")
+ output_file = join(network_training_output_dir, "summary_cascade.csv")
+
+ folds = (0, )
+ folds_str = ""
+ for f in folds:
+ folds_str += str(f)
+
+ plans = "nnUNetPlansv2.1"
+
+ overwrite_plans = {
+ 'nnUNetTrainerCascadeFullRes': ['nnUNetPlans'],
+ }
+
+ trainers = [
+ 'nnUNetTrainerCascadeFullRes',
+ 'nnUNetTrainerV2CascadeFullRes_EducatedGuess',
+ 'nnUNetTrainerV2CascadeFullRes_EducatedGuess2',
+ 'nnUNetTrainerV2CascadeFullRes_EducatedGuess3',
+ 'nnUNetTrainerV2CascadeFullRes_lowerLR',
+ 'nnUNetTrainerV2CascadeFullRes',
+ 'nnUNetTrainerV2CascadeFullRes_noConnComp',
+ 'nnUNetTrainerV2CascadeFullRes_shorter_lowerLR',
+ 'nnUNetTrainerV2CascadeFullRes_shorter',
+ 'nnUNetTrainerV2CascadeFullRes_smallerBinStrel',
+ ]
+
+ datasets = \
+ {
+ "Task003_Liver": ("3d_cascade_fullres", ),
+ "Task006_Lung": ("3d_cascade_fullres", ),
+ "Task007_Pancreas": ("3d_cascade_fullres", ),
+ "Task008_HepaticVessel": ("3d_cascade_fullres", ),
+ "Task009_Spleen": ("3d_cascade_fullres", ),
+ "Task010_Colon": ("3d_cascade_fullres", ),
+ "Task017_AbdominalOrganSegmentation": ("3d_cascade_fullres", ),
+ #"Task029_LITS": ("3d_cascade_fullres", ),
+ "Task048_KiTS_clean": ("3d_cascade_fullres", ),
+ "Task055_SegTHOR": ("3d_cascade_fullres", ),
+ "Task056_VerSe": ("3d_cascade_fullres", ),
+ #"": ("3d_cascade_fullres", ),
+ }
+
+ expected_validation_folder = "validation_raw"
+ alternative_validation_folder = "validation"
+ alternative_alternative_validation_folder = "validation_tiledTrue_doMirror_True"
+
+ interested_in = "mean"
+
+ result_per_dataset = {}
+ for d in datasets:
+ result_per_dataset[d] = {}
+ for c in datasets[d]:
+ result_per_dataset[d][c] = []
+
+ valid_trainers = []
+ all_trainers = []
+
+ with open(output_file, 'w') as f:
+ f.write("trainer,")
+ for t in datasets.keys():
+ s = t[4:7]
+ for c in datasets[t]:
+ s1 = s + "_" + c[3]
+ f.write("%s," % s1)
+ f.write("\n")
+
+ for trainer in trainers:
+ trainer_plans = [plans]
+ if trainer in overwrite_plans.keys():
+ trainer_plans = overwrite_plans[trainer]
+
+ result_per_dataset_here = {}
+ for d in datasets:
+ result_per_dataset_here[d] = {}
+
+ for p in trainer_plans:
+ name = "%s__%s" % (trainer, p)
+ all_present = True
+ all_trainers.append(name)
+
+ f.write("%s," % name)
+ for dataset in datasets.keys():
+ for configuration in datasets[dataset]:
+ summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p, expected_validation_folder, folds_str))
+ if not isfile(summary_file):
+ summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (dataset, configuration, trainer, p, alternative_validation_folder, folds_str))
+ if not isfile(summary_file):
+ summary_file = join(summary_files_dir, "%s__%s__%s__%s__%s__%s.json" % (
+ dataset, configuration, trainer, p, alternative_alternative_validation_folder, folds_str))
+ if not isfile(summary_file):
+ all_present = False
+ print(name, dataset, configuration, "has missing summary file")
+ if isfile(summary_file):
+ result = load_json(summary_file)['results'][interested_in]['mean']['Dice']
+ result_per_dataset_here[dataset][configuration] = result
+ f.write("%02.4f," % result)
+ else:
+ f.write("NA,")
+ result_per_dataset_here[dataset][configuration] = 0
+
+ f.write("\n")
+
+ # note: 'if True' keeps every trainer in the ranking even when summary files are missing
+ # (missing results were written as NA in the csv and recorded as 0 above)
+ if True:
+ valid_trainers.append(name)
+ for d in datasets:
+ for c in datasets[d]:
+ result_per_dataset[d][c].append(result_per_dataset_here[d][c])
+
+ invalid_trainers = [i for i in all_trainers if i not in valid_trainers]
+
+ num_valid = len(valid_trainers)
+ num_datasets = len(datasets.keys())
+ # create an array that is trainer x dataset. If more than one configuration is present, use the best metric across them
+ all_res = np.zeros((num_valid, num_datasets))
+ for j, d in enumerate(datasets.keys()):
+ ks = list(result_per_dataset[d].keys())
+ tmp = result_per_dataset[d][ks[0]]
+ for k in ks[1:]:
+ for i in range(len(tmp)):
+ tmp[i] = max(tmp[i], result_per_dataset[d][k][i])
+ all_res[:, j] = tmp
+
+ ranks_arr = np.zeros_like(all_res)
+ for d in range(ranks_arr.shape[1]):
+ temp = np.argsort(all_res[:, d])[::-1] # reversed so that the highest Dice gets rank 0
+ ranks = np.empty_like(temp)
+ ranks[temp] = np.arange(len(temp))
+
+ ranks_arr[:, d] = ranks
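+ # worked example of the double-argsort ranking (illustrative): all_res[:, d] = [0.7, 0.9, 0.8]
+ # -> temp = argsort(...)[::-1] = [1, 2, 0] (trainer indices from best to worst)
+ # -> ranks = [2, 0, 1], i.e. the trainer with Dice 0.9 gets rank 0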
+
+ mn = np.mean(ranks_arr, 1)
+ for i in np.argsort(mn):
+ print(mn[i], valid_trainers[i])
+
+ print()
+ print(valid_trainers[np.argmin(mn)])
diff --git a/nnunet/evaluation/model_selection/summarize_results_in_one_json.py b/nnunet/evaluation/model_selection/summarize_results_in_one_json.py
new file mode 100644
index 0000000000000000000000000000000000000000..7989e0966466ee6524efd848c5d869569095cab6
--- /dev/null
+++ b/nnunet/evaluation/model_selection/summarize_results_in_one_json.py
@@ -0,0 +1,236 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from collections import OrderedDict
+from nnunet.evaluation.add_mean_dice_to_json import foreground_mean
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.paths import network_training_output_dir
+import numpy as np
+
+
+def summarize(tasks, models=('2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'),
+ output_dir=join(network_training_output_dir, "summary_jsons"), folds=(0, 1, 2, 3, 4)):
+ maybe_mkdir_p(output_dir)
+
+ if len(tasks) == 1 and tasks[0] == "all":
+ tasks = list(range(999))
+ else:
+ tasks = [int(i) for i in tasks]
+
+ for model in models:
+ for t in tasks:
+ t = int(t)
+ if not isdir(join(network_training_output_dir, model)):
+ continue
+ task_name = subfolders(join(network_training_output_dir, model), prefix="Task%03.0d" % t, join=False)
+ if len(task_name) != 1:
+ print("did not find unique output folder for network %s and task %s" % (model, t))
+ continue
+ task_name = task_name[0]
+ out_dir_task = join(network_training_output_dir, model, task_name)
+
+ model_trainers = subdirs(out_dir_task, join=False)
+ for trainer in model_trainers:
+ if trainer.startswith("fold"):
+ continue
+ out_dir = join(out_dir_task, trainer)
+
+ validation_folders = []
+ for fld in folds:
+ d = join(out_dir, "fold%d"%fld)
+ if not isdir(d):
+ d = join(out_dir, "fold_%d"%fld)
+ if not isdir(d):
+ break
+ validation_folders += subfolders(d, prefix="validation", join=False)
+
+ for v in validation_folders:
+ ok = True
+ metrics = OrderedDict()
+ for fld in folds:
+ d = join(out_dir, "fold%d"%fld)
+ if not isdir(d):
+ d = join(out_dir, "fold_%d"%fld)
+ if not isdir(d):
+ ok = False
+ break
+ validation_folder = join(d, v)
+
+ if not isfile(join(validation_folder, "summary.json")):
+ print("summary.json missing for net %s task %s fold %d" % (model, task_name, fld))
+ ok = False
+ break
+
+ metrics_tmp = load_json(join(validation_folder, "summary.json"))["results"]["mean"]
+ for l in metrics_tmp.keys():
+ if metrics.get(l) is None:
+ metrics[l] = OrderedDict()
+ for m in metrics_tmp[l].keys():
+ if metrics[l].get(m) is None:
+ metrics[l][m] = []
+ metrics[l][m].append(metrics_tmp[l][m])
+ if ok:
+ for l in metrics.keys():
+ for m in metrics[l].keys():
+ assert len(metrics[l][m]) == len(folds)
+ metrics[l][m] = np.mean(metrics[l][m])
+ json_out = OrderedDict()
+ json_out["results"] = OrderedDict()
+ json_out["results"]["mean"] = metrics
+ json_out["task"] = task_name
+ json_out["description"] = model + " " + task_name + " all folds summary"
+ json_out["name"] = model + " " + task_name + " all folds summary"
+ json_out["experiment_name"] = model
+ save_json(json_out, join(out_dir, "summary_allFolds__%s.json" % v))
+ save_json(json_out, join(output_dir, "%s__%s__%s__%s.json" % (task_name, model, trainer, v)))
+ foreground_mean(join(out_dir, "summary_allFolds__%s.json" % v))
+ foreground_mean(join(output_dir, "%s__%s__%s__%s.json" % (task_name, model, trainer, v)))
+
+
+def summarize2(task_ids, models=('2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'),
+ output_dir=join(network_training_output_dir, "summary_jsons"), folds=(0, 1, 2, 3, 4)):
+ maybe_mkdir_p(output_dir)
+
+ if len(task_ids) == 1 and task_ids[0] == "all":
+ task_ids = list(range(999))
+ else:
+ task_ids = [int(i) for i in task_ids]
+
+ for model in models:
+ for t in task_ids:
+ if not isdir(join(network_training_output_dir, model)):
+ continue
+ task_name = subfolders(join(network_training_output_dir, model), prefix="Task%03.0d" % t, join=False)
+ if len(task_name) != 1:
+ print("did not find unique output folder for network %s and task %s" % (model, t))
+ continue
+ task_name = task_name[0]
+ out_dir_task = join(network_training_output_dir, model, task_name)
+
+ model_trainers = subdirs(out_dir_task, join=False)
+ for trainer in model_trainers:
+ if trainer.startswith("fold"):
+ continue
+ out_dir = join(out_dir_task, trainer)
+
+ validation_folders = []
+ for fld in folds:
+ fold_output_dir = join(out_dir, "fold_%d"%fld)
+ if not isdir(fold_output_dir):
+ continue
+ validation_folders += subfolders(fold_output_dir, prefix="validation", join=False)
+
+ validation_folders = np.unique(validation_folders)
+
+ for v in validation_folders:
+ ok = True
+ metrics = OrderedDict()
+ metrics['mean'] = OrderedDict()
+ metrics['median'] = OrderedDict()
+ metrics['all'] = OrderedDict()
+ for fld in folds:
+ fold_output_dir = join(out_dir, "fold_%d"%fld)
+
+ if not isdir(fold_output_dir):
+ print("fold missing", model, task_name, trainer, fld)
+ ok = False
+ break
+ validation_folder = join(fold_output_dir, v)
+
+ if not isdir(validation_folder):
+ print("validation folder missing", model, task_name, trainer, fld, v)
+ ok = False
+ break
+
+ if not isfile(join(validation_folder, "summary.json")):
+ print("summary.json missing", model, task_name, trainer, fld, v)
+ ok = False
+ break
+
+ all_metrics = load_json(join(validation_folder, "summary.json"))["results"]
+ # we now need to get the mean and median metrics. We use the mean metrics just to get the
+ # names of the computed metrics; we ignore the precomputed mean and compute it ourselves again
+ mean_metrics = all_metrics["mean"]
+ all_labels = [i for i in list(mean_metrics.keys()) if i != "mean"]
+
+ if len(all_labels) == 0:
+ print(v, fld)
+ break
+
+ all_metrics_names = list(mean_metrics[all_labels[0]].keys())
+ for l in all_labels:
+ # initialize the data structure, no values are copied yet
+ for k in ['mean', 'median', 'all']:
+ if metrics[k].get(l) is None:
+ metrics[k][l] = OrderedDict()
+ for m in all_metrics_names:
+ if metrics['all'][l].get(m) is None:
+ metrics['all'][l][m] = []
+ for entry in all_metrics['all']:
+ for l in all_labels:
+ for m in all_metrics_names:
+ metrics['all'][l][m].append(entry[l][m])
+ # now compute mean and median
+ for l in metrics['all'].keys():
+ for m in metrics['all'][l].keys():
+ metrics['mean'][l][m] = np.nanmean(metrics['all'][l][m])
+ metrics['median'][l][m] = np.nanmedian(metrics['all'][l][m])
+ if ok:
+ fold_string = ""
+ for f in folds:
+ fold_string += str(f)
+ json_out = OrderedDict()
+ json_out["results"] = OrderedDict()
+ json_out["results"]["mean"] = metrics['mean']
+ json_out["results"]["median"] = metrics['median']
+ json_out["task"] = task_name
+ json_out["description"] = model + " " + task_name + "summary folds" + str(folds)
+ json_out["name"] = model + " " + task_name + "summary folds" + str(folds)
+ json_out["experiment_name"] = model
+ save_json(json_out, join(output_dir, "%s__%s__%s__%s__%s.json" % (task_name, model, trainer, v, fold_string)))
+ foreground_mean2(join(output_dir, "%s__%s__%s__%s__%s.json" % (task_name, model, trainer, v, fold_string)))
+
+
+def foreground_mean2(filename):
+ with open(filename, 'r') as f:
+ res = json.load(f)
+ class_ids = np.array([int(i) for i in res['results']['mean'].keys() if (i != 'mean') and i != '0'])
+
+ metric_names = res['results']['mean']['1'].keys()
+ res['results']['mean']["mean"] = OrderedDict()
+ res['results']['median']["mean"] = OrderedDict()
+ for m in metric_names:
+ foreground_values = [res['results']['mean'][str(i)][m] for i in class_ids]
+ res['results']['mean']["mean"][m] = np.nanmean(foreground_values)
+ foreground_values = [res['results']['median'][str(i)][m] for i in class_ids]
+ res['results']['median']["mean"][m] = np.nanmean(foreground_values)
+ with open(filename, 'w') as f:
+ json.dump(res, f, indent=4, sort_keys=True)
+
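+# example (illustrative): for labels {'0', '1', '2'}, class_ids becomes [1, 2] (background '0' and the
+# existing 'mean' key are excluded) and each metric's new 'mean' entry averages over classes 1 and 2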
+
+if __name__ == "__main__":
+ import argparse
+ parser = argparse.ArgumentParser(usage="This is intended to identify the best model based on the five fold "
+ "cross-validation. Running this script requires alle models to have been run "
+ "already. This script will summarize the results of the five folds of all "
+ "models in one json each for easy interpretability")
+ parser.add_argument("-t", '--task_ids', nargs="+", required=True, help="task id. can be 'all'")
+ parser.add_argument("-f", '--folds', nargs="+", required=False, type=int, default=[0, 1, 2, 3, 4])
+ parser.add_argument("-m", '--models', nargs="+", required=False, default=['2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'])
+
+ args = parser.parse_args()
+ tasks = args.task_ids
+ models = args.models
+
+ folds = args.folds
+ summarize2(tasks, models, folds=folds, output_dir=join(network_training_output_dir, "summary_jsons_new"))
+
diff --git a/nnunet/evaluation/model_selection/summarize_results_with_plans.py b/nnunet/evaluation/model_selection/summarize_results_with_plans.py
new file mode 100644
index 0000000000000000000000000000000000000000..964e9b2f7b6dd4f7413c7882601e796568819785
--- /dev/null
+++ b/nnunet/evaluation/model_selection/summarize_results_with_plans.py
@@ -0,0 +1,110 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from batchgenerators.utilities.file_and_folder_operations import *
+import os
+from nnunet.evaluation.model_selection.summarize_results_in_one_json import summarize
+from nnunet.paths import network_training_output_dir
+import numpy as np
+
+
+def list_to_string(l, delim=","):
+ st = "%03.3f" % l[0]
+ for i in l[1:]:
+ st += delim + "%03.3f" % i
+ return st
+
+
+def write_plans_to_file(f, plans_file, stage=0, do_linebreak_at_end=True, override_name=None):
+ a = load_pickle(plans_file)
+ stages = list(a['plans_per_stage'].keys())
+ stages.sort()
+ patch_size_in_mm = [i * j for i, j in zip(a['plans_per_stage'][stages[stage]]['patch_size'],
+ a['plans_per_stage'][stages[stage]]['current_spacing'])]
+ median_patient_size_in_mm = [i * j for i, j in zip(a['plans_per_stage'][stages[stage]]['median_patient_size_in_voxels'],
+ a['plans_per_stage'][stages[stage]]['current_spacing'])]
+ if override_name is None:
+ f.write(plans_file.split("/")[-2] + "__" + plans_file.split("/")[-1])
+ else:
+ f.write(override_name)
+ f.write(";%d" % stage)
+ f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['batch_size']))
+ f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['num_pool_per_axis']))
+ f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['patch_size']))
+ f.write(";%s" % list_to_string(patch_size_in_mm))
+ f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['median_patient_size_in_voxels']))
+ f.write(";%s" % list_to_string(median_patient_size_in_mm))
+ f.write(";%s" % list_to_string(a['plans_per_stage'][stages[stage]]['current_spacing']))
+ f.write(";%s" % list_to_string(a['plans_per_stage'][stages[stage]]['original_spacing']))
+ f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['pool_op_kernel_sizes']))
+ f.write(";%s" % str(a['plans_per_stage'][stages[stage]]['conv_kernel_sizes']))
+ if do_linebreak_at_end:
+ f.write("\n")
+
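+# each write_plans_to_file call emits one semicolon-separated row matching the header row written in
+# __main__ below (identifier;stage;batch_size;...;conv_kernel_sizes, with two Dice columns appended later)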
+
+if __name__ == "__main__":
+ summarize((1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 24, 27), output_dir=join(network_training_output_dir, "summary_fold0"), folds=(0,))
+ base_dir = os.environ['RESULTS_FOLDER']
+ nnunets = ['nnUNetV2', 'nnUNetV2_zspacing']
+ task_ids = list(range(99))
+ with open("summary.csv", 'w') as f:
+ f.write("identifier;stage;batch_size;num_pool_per_axis;patch_size;patch_size(mm);median_patient_size_in_voxels;median_patient_size_in_mm;current_spacing;original_spacing;pool_op_kernel_sizes;conv_kernel_sizes;patient_dc;global_dc\n")
+ for i in task_ids:
+ for nnunet in nnunets:
+ try:
+ summary_folder = join(base_dir, nnunet, "summary_fold0")
+ if isdir(summary_folder):
+ summary_files = subfiles(summary_folder, join=False, prefix="Task%03.0d_" % i, suffix=".json", sort=True)
+ for s in summary_files:
+ tmp = s.split("__")
+ trainer = tmp[2]
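+ # summary filenames are '__'-separated, e.g. (illustrative) 'Task003_Liver__3d_fullres__nnUNetTrainerV2__validation_raw.json';
+ # tmp[0] = task, tmp[1] = model, the remaining parts identify the trainer and validation folder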
+
+ expected_output_folder = join(base_dir, nnunet, tmp[1], tmp[0], tmp[2].split(".")[0])
+ name = tmp[0] + "__" + nnunet + "__" + tmp[1] + "__" + tmp[2].split(".")[0]
+ global_dice_json = join(base_dir, nnunet, tmp[1], tmp[0], tmp[2].split(".")[0], "fold_0", "validation_tiledTrue_doMirror_True", "global_dice.json")
+
+ if not isdir(expected_output_folder) or len(tmp) > 3:
+ if len(tmp) == 2:
+ continue
+ expected_output_folder = join(base_dir, nnunet, tmp[1], tmp[0], tmp[2] + "__" + tmp[3].split(".")[0])
+ name = tmp[0] + "__" + nnunet + "__" + tmp[1] + "__" + tmp[2] + "__" + tmp[3].split(".")[0]
+ global_dice_json = join(base_dir, nnunet, tmp[1], tmp[0], tmp[2] + "__" + tmp[3].split(".")[0], "fold_0", "validation_tiledTrue_doMirror_True", "global_dice.json")
+
+ assert isdir(expected_output_folder), "expected output dir not found"
+ plans_file = join(expected_output_folder, "plans.pkl")
+ assert isfile(plans_file)
+
+ plans = load_pickle(plans_file)
+ num_stages = len(plans['plans_per_stage'])
+ if num_stages > 1 and tmp[1] == "3d_fullres":
+ stage = 1
+ elif (num_stages == 1 and tmp[1] == "3d_fullres") or tmp[1] == "3d_lowres":
+ stage = 0
+ else:
+ print("skipping", s)
+ continue
+
+ g_dc = load_json(global_dice_json)
+ mn_glob_dc = np.mean(list(g_dc.values()))
+
+ write_plans_to_file(f, plans_file, stage, False, name)
+ # now read and add result to end of line
+ results = load_json(join(summary_folder, s))
+ mean_dc = results['results']['mean']['mean']['Dice']
+ f.write(";%03.3f" % mean_dc)
+ f.write(";%03.3f\n" % mn_glob_dc)
+ print(name, mean_dc)
+ except Exception as e:
+ print(e)
diff --git a/nnunet/evaluation/region_based_evaluation.py b/nnunet/evaluation/region_based_evaluation.py
new file mode 100644
index 0000000000000000000000000000000000000000..31e9b0cbfd0d3f466a2139ff113190fa75d1d57b
--- /dev/null
+++ b/nnunet/evaluation/region_based_evaluation.py
@@ -0,0 +1,115 @@
+from copy import deepcopy
+from multiprocessing.pool import Pool
+
+from batchgenerators.utilities.file_and_folder_operations import *
+from medpy import metric
+import SimpleITK as sitk
+import numpy as np
+from nnunet.configuration import default_num_threads
+from nnunet.postprocessing.consolidate_postprocessing import collect_cv_niftis
+
+
+def get_brats_regions():
+ """
+ this is only valid for the BraTS data in here, where the labels are 1, 2, and 3. The original BraTS data use a
+ different labeling convention!
+ :return:
+ """
+ regions = {
+ "whole tumor": (1, 2, 3),
+ "tumor core": (2, 3),
+ "enhancing tumor": (3,)
+ }
+ return regions
+
+
+def get_KiTS_regions():
+ regions = {
+ "kidney incl tumor": (1, 2),
+ "tumor": (2,)
+ }
+ return regions
+
+
+def create_region_from_mask(mask, join_labels: tuple):
+ mask_new = np.zeros_like(mask, dtype=np.uint8)
+ for l in join_labels:
+ mask_new[mask == l] = 1
+ return mask_new
+
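+# example (illustrative): with the BraTS convention above, create_region_from_mask(seg, (2, 3)) returns a
+# binary mask that is 1 wherever seg is 2 or 3 (the tumor core) and 0 elsewhere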
+
+def evaluate_case(file_pred: str, file_gt: str, regions):
+ image_gt = sitk.GetArrayFromImage(sitk.ReadImage(file_gt))
+ image_pred = sitk.GetArrayFromImage(sitk.ReadImage(file_pred))
+ results = []
+ for r in regions:
+ mask_pred = create_region_from_mask(image_pred, r)
+ mask_gt = create_region_from_mask(image_gt, r)
+ dc = np.nan if np.sum(mask_gt) == 0 and np.sum(mask_pred) == 0 else metric.dc(mask_pred, mask_gt)
+ results.append(dc)
+ return results
+
+
+def evaluate_regions(folder_predicted: str, folder_gt: str, regions: dict, processes=default_num_threads):
+ region_names = list(regions.keys())
+ files_in_pred = subfiles(folder_predicted, suffix='.nii.gz', join=False)
+ files_in_gt = subfiles(folder_gt, suffix='.nii.gz', join=False)
+ have_no_gt = [i for i in files_in_pred if i not in files_in_gt]
+ assert len(have_no_gt) == 0, "Some files in folder_predicted have no ground truth in folder_gt"
+ have_no_pred = [i for i in files_in_gt if i not in files_in_pred]
+ if len(have_no_pred) > 0:
+ print("WARNING! Some files in folder_gt were not predicted (not present in folder_predicted)!")
+
+ files_in_gt.sort()
+ files_in_pred.sort()
+
+ # run for all cases
+ full_filenames_gt = [join(folder_gt, i) for i in files_in_pred]
+ full_filenames_pred = [join(folder_predicted, i) for i in files_in_pred]
+
+ p = Pool(processes)
+ res = p.starmap(evaluate_case, zip(full_filenames_pred, full_filenames_gt, [list(regions.values())] * len(files_in_gt)))
+ p.close()
+ p.join()
+
+ all_results = {r: [] for r in region_names}
+ with open(join(folder_predicted, 'summary.csv'), 'w') as f:
+ f.write("casename")
+ for r in region_names:
+ f.write(",%s" % r)
+ f.write("\n")
+ for i in range(len(files_in_pred)):
+ f.write(files_in_pred[i][:-7])
+ result_here = res[i]
+ for k, r in enumerate(region_names):
+ dc = result_here[k]
+ f.write(",%02.4f" % dc)
+ all_results[r].append(dc)
+ f.write("\n")
+
+ f.write('mean')
+ for r in region_names:
+ f.write(",%02.4f" % np.nanmean(all_results[r]))
+ f.write("\n")
+ f.write('median')
+ for r in region_names:
+ f.write(",%02.4f" % np.nanmedian(all_results[r]))
+ f.write("\n")
+
+ f.write('mean (nan is 1)')
+ for r in region_names:
+ tmp = np.array(all_results[r])
+ tmp[np.isnan(tmp)] = 1
+ f.write(",%02.4f" % np.mean(tmp))
+ f.write("\n")
+ f.write('median (nan is 1)')
+ for r in region_names:
+ tmp = np.array(all_results[r])
+ tmp[np.isnan(tmp)] = 1
+ f.write(",%02.4f" % np.median(tmp))
+ f.write("\n")
+
+
+if __name__ == '__main__':
+ collect_cv_niftis('./', './cv_niftis')
+ evaluate_regions('./cv_niftis/', './gt_niftis/', get_brats_regions())
diff --git a/nnunet/evaluation/surface_dice.py b/nnunet/evaluation/surface_dice.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ce5ffdee35071fe65bd81ca87f60b97401c8801
--- /dev/null
+++ b/nnunet/evaluation/surface_dice.py
@@ -0,0 +1,57 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+from medpy.metric.binary import __surface_distances
+
+
+def normalized_surface_dice(a: np.ndarray, b: np.ndarray, threshold: float, spacing: tuple = None, connectivity=1):
+ """
+ This implementation differs from the official surface dice implementation! These two are not comparable!!!!!
+
+ The normalized surface dice is symmetric, so it should not matter whether a or b is the reference image
+
+ This implementation natively supports 2D and 3D images. Whether other dimensions are supported depends on the
+ __surface_distances implementation in medpy
+
+ :param a: image 1, must have the same shape as b
+ :param b: image 2, must have the same shape as a
+ :param threshold: distances below this threshold will be counted as true positives. The threshold is in mm, not
+ voxels! (if spacing = (1, 1(, 1)) then one voxel = 1mm, so the threshold is effectively in voxels)
+ :param spacing: how many mm is one voxel in reality? Must be a tuple of len dimension(a). Can be left at None,
+ in which case we assume an isotropic spacing of 1mm
+ :param connectivity: see scipy.ndimage.generate_binary_structure for more information. I suggest you leave that
+ one alone
+ :return:
+ """
+ assert all([i == j for i, j in zip(a.shape, b.shape)]), "a and b must have the same shape. a.shape= %s, " \
+ "b.shape= %s" % (str(a.shape), str(b.shape))
+ if spacing is None:
+ spacing = tuple([1 for _ in range(len(a.shape))])
+ a_to_b = __surface_distances(a, b, spacing, connectivity)
+ b_to_a = __surface_distances(b, a, spacing, connectivity)
+
+ numel_a = len(a_to_b)
+ numel_b = len(b_to_a)
+
+ tp_a = np.sum(a_to_b <= threshold) / numel_a
+ tp_b = np.sum(b_to_a <= threshold) / numel_b
+
+ fp = np.sum(a_to_b > threshold) / numel_a
+ fn = np.sum(b_to_a > threshold) / numel_b
+
+ dc = (tp_a + tp_b) / (tp_a + tp_b + fp + fn + 1e-8) # 1e-8 just so that we don't get div by 0
+ return dc
+
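+
+if __name__ == '__main__':
+ # minimal usage sketch (illustrative, not part of the original module): two toy 3D masks shifted by one
+ # voxel; with isotropic 1mm spacing a threshold of 1.0 counts surface points within one voxel as matching
+ a = np.zeros((32, 32, 32), dtype=np.uint8)
+ a[8:24, 8:24, 8:24] = 1
+ b = np.zeros((32, 32, 32), dtype=np.uint8)
+ b[9:25, 8:24, 8:24] = 1
+ print(normalized_surface_dice(a, b, threshold=1.0, spacing=(1., 1., 1.)))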
diff --git a/nnunet/experiment_planning/DatasetAnalyzer.py b/nnunet/experiment_planning/DatasetAnalyzer.py
new file mode 100644
index 0000000000000000000000000000000000000000..31e84d8aef186e34f6b23fb63124dbba61d96a57
--- /dev/null
+++ b/nnunet/experiment_planning/DatasetAnalyzer.py
@@ -0,0 +1,256 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from batchgenerators.utilities.file_and_folder_operations import *
+from multiprocessing import Pool
+
+from nnunet.configuration import default_num_threads
+from nnunet.paths import nnUNet_raw_data, nnUNet_cropped_data
+import numpy as np
+import pickle
+from nnunet.preprocessing.cropping import get_patient_identifiers_from_cropped_files
+from skimage.morphology import label
+from collections import OrderedDict
+
+
+class DatasetAnalyzer(object):
+ def __init__(self, folder_with_cropped_data, overwrite=True, num_processes=default_num_threads):
+ """
+ :param folder_with_cropped_data:
+ :param overwrite: If True then precomputed values will not be used and instead recomputed from the data.
+ False will allow loading of precomputed values. This may be dangerous though if some of the code of this class
+ was changed, therefore the default is True.
+ """
+ self.num_processes = num_processes
+ self.overwrite = overwrite
+ self.folder_with_cropped_data = folder_with_cropped_data
+ self.sizes = self.spacings = None
+ self.patient_identifiers = get_patient_identifiers_from_cropped_files(self.folder_with_cropped_data)
+ assert isfile(join(self.folder_with_cropped_data, "dataset.json")), \
+ "dataset.json needs to be in folder_with_cropped_data"
+ self.props_per_case_file = join(self.folder_with_cropped_data, "props_per_case.pkl")
+ self.intensityproperties_file = join(self.folder_with_cropped_data, "intensityproperties.pkl")
+
+ def load_properties_of_cropped(self, case_identifier):
+ with open(join(self.folder_with_cropped_data, "%s.pkl" % case_identifier), 'rb') as f:
+ properties = pickle.load(f)
+ return properties
+
+ @staticmethod
+ def _check_if_all_in_one_region(seg, regions):
+ res = OrderedDict()
+ for r in regions:
+ new_seg = np.zeros(seg.shape)
+ for c in r:
+ new_seg[seg == c] = 1
+ labelmap, numlabels = label(new_seg, return_num=True)
+ if numlabels != 1:
+ res[tuple(r)] = False
+ else:
+ res[tuple(r)] = True
+ return res
+
+ @staticmethod
+ def _collect_class_and_region_sizes(seg, all_classes, vol_per_voxel):
+ volume_per_class = OrderedDict()
+ region_volume_per_class = OrderedDict()
+ for c in all_classes:
+ region_volume_per_class[c] = []
+ volume_per_class[c] = np.sum(seg == c) * vol_per_voxel
+ labelmap, numregions = label(seg == c, return_num=True)
+ for l in range(1, numregions + 1):
+ region_volume_per_class[c].append(np.sum(labelmap == l) * vol_per_voxel)
+ return volume_per_class, region_volume_per_class
+
+ def _get_unique_labels(self, patient_identifier):
+ seg = np.load(join(self.folder_with_cropped_data, patient_identifier) + ".npz")['data'][-1]
+ unique_classes = np.unique(seg)
+ return unique_classes
+
+ def _load_seg_analyze_classes(self, patient_identifier, all_classes):
+ """
+ 1) what class is in this training case?
+ 2) what is the size distribution for each class?
+ 3) what is the region size of each class?
+ 4) check if all in one region
+ :return:
+ """
+ seg = np.load(join(self.folder_with_cropped_data, patient_identifier) + ".npz")['data'][-1]
+ pkl = load_pickle(join(self.folder_with_cropped_data, patient_identifier) + ".pkl")
+ vol_per_voxel = np.prod(pkl['itk_spacing'])
+
+ # ad 1)
+ unique_classes = np.unique(seg)
+
+ # 4) check if all in one region
+ regions = list()
+ regions.append(list(all_classes))
+ for c in all_classes:
+ regions.append((c, ))
+
+ all_in_one_region = self._check_if_all_in_one_region(seg, regions)
+
+ # 2 & 3) region sizes
+ volume_per_class, region_sizes = self._collect_class_and_region_sizes(seg, all_classes, vol_per_voxel)
+
+ return unique_classes, all_in_one_region, volume_per_class, region_sizes
+
+ def get_classes(self):
+ datasetjson = load_json(join(self.folder_with_cropped_data, "dataset.json"))
+ return datasetjson['labels']
+
+ def analyse_segmentations(self):
+ class_dct = self.get_classes()
+
+ if self.overwrite or not isfile(self.props_per_case_file):
+ p = Pool(self.num_processes)
+ res = p.map(self._get_unique_labels, self.patient_identifiers)
+ p.close()
+ p.join()
+
+ props_per_patient = OrderedDict()
+ for p, unique_classes in \
+ zip(self.patient_identifiers, res):
+ props = dict()
+ props['has_classes'] = unique_classes
+ props_per_patient[p] = props
+
+ save_pickle(props_per_patient, self.props_per_case_file)
+ else:
+ props_per_patient = load_pickle(self.props_per_case_file)
+ return class_dct, props_per_patient
+
+ def get_sizes_and_spacings_after_cropping(self):
+ sizes = []
+ spacings = []
+ # for c in case_identifiers:
+ for c in self.patient_identifiers:
+ properties = self.load_properties_of_cropped(c)
+ sizes.append(properties["size_after_cropping"])
+ spacings.append(properties["original_spacing"])
+
+ return sizes, spacings
+
+ def get_modalities(self):
+ datasetjson = load_json(join(self.folder_with_cropped_data, "dataset.json"))
+ modalities = datasetjson["modality"]
+ modalities = {int(k): modalities[k] for k in modalities.keys()}
+ return modalities
+
+ def get_size_reduction_by_cropping(self):
+ size_reduction = OrderedDict()
+ for p in self.patient_identifiers:
+ props = self.load_properties_of_cropped(p)
+ shape_before_crop = props["original_size_of_raw_data"]
+ shape_after_crop = props['size_after_cropping']
+ size_red = np.prod(shape_after_crop) / np.prod(shape_before_crop)
+ size_reduction[p] = size_red
+ return size_reduction
+
+ def _get_voxels_in_foreground(self, patient_identifier, modality_id):
+ all_data = np.load(join(self.folder_with_cropped_data, patient_identifier) + ".npz")['data']
+ modality = all_data[modality_id]
+ mask = all_data[-1] > 0
+ voxels = list(modality[mask][::10]) # no need to take every voxel
+ return voxels
+
+ @staticmethod
+ def _compute_stats(voxels):
+ if len(voxels) == 0:
+ return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan
+ median = np.median(voxels)
+ mean = np.mean(voxels)
+ sd = np.std(voxels)
+ mn = np.min(voxels)
+ mx = np.max(voxels)
+ percentile_99_5 = np.percentile(voxels, 99.5)
+ percentile_00_5 = np.percentile(voxels, 00.5)
+ return median, mean, sd, mn, mx, percentile_99_5, percentile_00_5
+
+ def collect_intensity_properties(self, num_modalities):
+ if self.overwrite or not isfile(self.intensityproperties_file):
+ p = Pool(self.num_processes)
+
+ results = OrderedDict()
+ for mod_id in range(num_modalities):
+ results[mod_id] = OrderedDict()
+ v = p.starmap(self._get_voxels_in_foreground, zip(self.patient_identifiers,
+ [mod_id] * len(self.patient_identifiers)))
+
+ w = []
+ for iv in v:
+ w += iv
+
+ median, mean, sd, mn, mx, percentile_99_5, percentile_00_5 = self._compute_stats(w)
+
+ local_props = p.map(self._compute_stats, v)
+ props_per_case = OrderedDict()
+ for i, pat in enumerate(self.patient_identifiers):
+ props_per_case[pat] = OrderedDict()
+ props_per_case[pat]['median'] = local_props[i][0]
+ props_per_case[pat]['mean'] = local_props[i][1]
+ props_per_case[pat]['sd'] = local_props[i][2]
+ props_per_case[pat]['mn'] = local_props[i][3]
+ props_per_case[pat]['mx'] = local_props[i][4]
+ props_per_case[pat]['percentile_99_5'] = local_props[i][5]
+ props_per_case[pat]['percentile_00_5'] = local_props[i][6]
+
+ results[mod_id]['local_props'] = props_per_case
+ results[mod_id]['median'] = median
+ results[mod_id]['mean'] = mean
+ results[mod_id]['sd'] = sd
+ results[mod_id]['mn'] = mn
+ results[mod_id]['mx'] = mx
+ results[mod_id]['percentile_99_5'] = percentile_99_5
+ results[mod_id]['percentile_00_5'] = percentile_00_5
+
+ p.close()
+ p.join()
+ save_pickle(results, self.intensityproperties_file)
+ else:
+ results = load_pickle(self.intensityproperties_file)
+ return results
+
+ def analyze_dataset(self, collect_intensityproperties=True):
+ # get all spacings and sizes
+ sizes, spacings = self.get_sizes_and_spacings_after_cropping()
+
+ # get all classes and what classes are in what patients
+ # class min size
+ # region size per class
+ classes = self.get_classes()
+ all_classes = [int(i) for i in classes.keys() if int(i) > 0]
+
+ # modalities
+ modalities = self.get_modalities()
+
+ # collect intensity information
+ if collect_intensityproperties:
+ intensityproperties = self.collect_intensity_properties(len(modalities))
+ else:
+ intensityproperties = None
+
+ # size reduction by cropping
+ size_reductions = self.get_size_reduction_by_cropping()
+
+ dataset_properties = dict()
+ dataset_properties['all_sizes'] = sizes
+ dataset_properties['all_spacings'] = spacings
+ dataset_properties['all_classes'] = all_classes
+ dataset_properties['modalities'] = modalities # {idx: modality name}
+ dataset_properties['intensityproperties'] = intensityproperties
+ dataset_properties['size_reductions'] = size_reductions # {patient_id: size_reduction}
+
+ save_pickle(dataset_properties, join(self.folder_with_cropped_data, "dataset_properties.pkl"))
+ return dataset_properties
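+
+
+if __name__ == '__main__':
+ # hedged usage sketch (the task name is a placeholder): analyze a cropped dataset and write
+ # dataset_properties.pkl into the cropped-data folder
+ analyzer = DatasetAnalyzer(join(nnUNet_cropped_data, 'Task002_Heart'), overwrite=False)
+ props = analyzer.analyze_dataset(collect_intensityproperties=True)
+ print(props['all_classes'])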
diff --git a/nnunet/experiment_planning/__init__.py b/nnunet/experiment_planning/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b8078b9dddddf22182fec2555d8d118ea72622
--- /dev/null
+++ b/nnunet/experiment_planning/__init__.py
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *
\ No newline at end of file
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/__init__.py b/nnunet/experiment_planning/alternative_experiment_planning/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_11GB.py b/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_11GB.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc5dbe9f4fb1ce17c6acc01f6480ff25fcfee94d
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_11GB.py
@@ -0,0 +1,124 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+
+import numpy as np
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
+ ExperimentPlanner3D_v21
+from nnunet.experiment_planning.common_utils import get_pool_and_conv_props
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.paths import *
+
+
+class ExperimentPlanner3D_v21_11GB(ExperimentPlanner3D_v21):
+ """
+ Same as ExperimentPlanner3D_v21, but designed to fill an RTX 2080 Ti (11GB) in fp16
+ """
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super(ExperimentPlanner3D_v21_11GB, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.data_identifier = "nnUNetData_plans_v2.1_big"
+ self.plans_fname = join(self.preprocessed_output_folder,
+ "nnUNetPlansv2.1_big_plans_3D.pkl")
+
+ def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
+ num_modalities, num_classes):
+ """
+ We need to adapt ref (the assumed VRAM budget) to the larger GPU memory
+ """
+ new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
+ dataset_num_voxels = np.prod(new_median_shape) * num_cases
+
+ # the next line is what we had before as a default. The patch size had the same aspect ratio as the median
+ # shape of a patient. We swapped that for an isotropic patch size in mm, computed below
+ # input_patch_size = new_median_shape
+
+ # compute how many voxels are one mm
+ input_patch_size = 1 / np.array(current_spacing)
+
+ # normalize voxels per mm
+ input_patch_size /= input_patch_size.mean()
+
+ # create an isotropic patch of size 512x512x512mm
+ input_patch_size *= 1 / min(input_patch_size) * 512 # to get a starting value
+ input_patch_size = np.round(input_patch_size).astype(int)
+
+ # clip it to the median shape of the dataset because patches larger than that make little sense
+ input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]
+
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool)
+ # use_this_for_batch_size_computation_3D = 520000000 # 505789440
+ # typical ExperimentPlanner3D_v21 configurations use ~8.6GB, but on a 2080 Ti we have 11GB. Allow more
+ # memory to be used
+ ref = Generic_UNet.use_this_for_batch_size_computation_3D * 11 / 8 # divide by 8 instead of 8.5, deliberately leaving less safety margin
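+ # e.g. with the constant 520000000 noted above this gives ref = 520000000 * 11 / 8 = 715000000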
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes,
+ pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
+ while here > ref:
+ axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]
+
+ tmp = deepcopy(new_shp)
+ tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
+ _, _, _, _, shape_must_be_divisible_by_new = \
+ get_pool_and_conv_props(current_spacing, tmp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ )
+ new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]
+
+ # we have to recompute numpool now:
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ )
+
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes, pool_op_kernel_sizes,
+ conv_per_stage=self.conv_per_stage)
+ # print(new_shp)
+
+ input_patch_size = new_shp
+
+ batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D # This is what works with 128**3 patches
+ batch_size = int(np.floor(max(ref / here, 1) * batch_size))
+
+ # check if batch size is too large
+ max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
+ np.prod(input_patch_size, dtype=np.int64)).astype(int)
+ max_batch_size = max(max_batch_size, self.unet_min_batch_size)
+ batch_size = max(1, min(batch_size, max_batch_size))
+
+ do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
+ 0]) > self.anisotropy_threshold
+
+ plan = {
+ 'batch_size': batch_size,
+ 'num_pool_per_axis': network_num_pool_per_axis,
+ 'patch_size': input_patch_size,
+ 'median_patient_size_in_voxels': new_median_shape,
+ 'current_spacing': current_spacing,
+ 'original_spacing': original_spacing,
+ 'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
+ 'pool_op_kernel_sizes': pool_op_kernel_sizes,
+ 'conv_kernel_sizes': conv_kernel_sizes,
+ }
+ return plan
+
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_16GB.py b/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_16GB.py
new file mode 100644
index 0000000000000000000000000000000000000000..484cb46b9981e94cd23e193052a4e2ce4376eee4
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_16GB.py
@@ -0,0 +1,124 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+
+import numpy as np
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
+ ExperimentPlanner3D_v21
+from nnunet.experiment_planning.common_utils import get_pool_and_conv_props
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.paths import *
+
+
+class ExperimentPlanner3D_v21_16GB(ExperimentPlanner3D_v21):
+ """
+ Same as ExperimentPlanner3D_v21, but designed to fill 16GB in fp16
+ """
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super(ExperimentPlanner3D_v21_16GB, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.data_identifier = "nnUNetData_plans_v2.1_16GB"
+ self.plans_fname = join(self.preprocessed_output_folder,
+ "nnUNetPlansv2.1_16GB_plans_3D.pkl")
+
+ def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
+ num_modalities, num_classes):
+ """
+ We need to adapt ref (the assumed VRAM budget) to the larger GPU memory
+ """
+ new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
+ dataset_num_voxels = np.prod(new_median_shape) * num_cases
+
+ # the next line is what we had before as a default. The patch size had the same aspect ratio as the median
+ # shape of a patient. We swapped that for an isotropic patch size in mm, computed below
+ # input_patch_size = new_median_shape
+
+ # compute how many voxels are one mm
+ input_patch_size = 1 / np.array(current_spacing)
+
+ # normalize voxels per mm
+ input_patch_size /= input_patch_size.mean()
+
+ # create an isotropic patch of size 512x512x512mm
+ input_patch_size *= 1 / min(input_patch_size) * 512 # to get a starting value
+ input_patch_size = np.round(input_patch_size).astype(int)
+
+ # clip it to the median shape of the dataset because patches larger than that make little sense
+ input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]
+
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool)
+ # use_this_for_batch_size_computation_3D = 520000000 # 505789440
+ # typical ExperimentPlanner3D_v21 configurations use ~8.5GB, but here we allow for 16GB
+ # to be used
+ ref = Generic_UNet.use_this_for_batch_size_computation_3D * 16 / 8.5
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes,
+ pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
+ while here > ref:
+ axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]
+
+ tmp = deepcopy(new_shp)
+ tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
+ _, _, _, _, shape_must_be_divisible_by_new = \
+ get_pool_and_conv_props(current_spacing, tmp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ )
+ new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]
+
+ # we have to recompute numpool now:
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ )
+
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes, pool_op_kernel_sizes,
+ conv_per_stage=self.conv_per_stage)
+ # print(new_shp)
+
+ input_patch_size = new_shp
+
+ batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D # This is what works with 128**3 patches
+ batch_size = int(np.floor(max(ref / here, 1) * batch_size))
+
+ # check if batch size is too large
+ max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
+ np.prod(input_patch_size, dtype=np.int64)).astype(int)
+ max_batch_size = max(max_batch_size, self.unet_min_batch_size)
+ batch_size = max(1, min(batch_size, max_batch_size))
+
+ do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
+ 0]) > self.anisotropy_threshold
+
+ plan = {
+ 'batch_size': batch_size,
+ 'num_pool_per_axis': network_num_pool_per_axis,
+ 'patch_size': input_patch_size,
+ 'median_patient_size_in_voxels': new_median_shape,
+ 'current_spacing': current_spacing,
+ 'original_spacing': original_spacing,
+ 'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
+ 'pool_op_kernel_sizes': pool_op_kernel_sizes,
+ 'conv_kernel_sizes': conv_kernel_sizes,
+ }
+ return plan
+
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_32GB.py b/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_32GB.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2759cedd978ae7d727ff89c910a34a4e6fad5eb
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_32GB.py
@@ -0,0 +1,122 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+
+import numpy as np
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
+ ExperimentPlanner3D_v21
+from nnunet.experiment_planning.common_utils import get_pool_and_conv_props
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.paths import *
+
+
+class ExperimentPlanner3D_v21_32GB(ExperimentPlanner3D_v21):
+ """
+ Same as ExperimentPlanner3D_v21, but designed to fill a V100 (32GB) in fp16
+ """
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super(ExperimentPlanner3D_v21_32GB, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.data_identifier = "nnUNetData_plans_v2.1_verybig"
+ self.plans_fname = join(self.preprocessed_output_folder,
+ "nnUNetPlansv2.1_verybig_plans_3D.pkl")
+
+ def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
+ num_modalities, num_classes):
+ """
+ We need to adapt ref (the assumed VRAM budget) to the larger GPU memory
+ """
+ new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
+ dataset_num_voxels = np.prod(new_median_shape) * num_cases
+
+ # the next line is what we had before as a default. The patch size had the same aspect ratio as the median
+ # shape of a patient. We swapped that for an isotropic patch size in mm, computed below
+ # input_patch_size = new_median_shape
+
+ # compute how many voxels are one mm
+ input_patch_size = 1 / np.array(current_spacing)
+
+ # normalize voxels per mm
+ input_patch_size /= input_patch_size.mean()
+
+ # create an isotropic patch of size 512x512x512mm
+ input_patch_size *= 1 / min(input_patch_size) * 512 # to get a starting value
+ input_patch_size = np.round(input_patch_size).astype(int)
+
+ # clip it to the median shape of the dataset because patches larger than that make little sense
+ input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]
+
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool)
+ # use_this_for_batch_size_computation_3D = 520000000 # 505789440
+ # typical ExperimentPlanner3D_v21 configurations use ~8.5GB, but on a V100 we have 32GB. Allow more
+ # memory to be used
+ ref = Generic_UNet.use_this_for_batch_size_computation_3D * 32 / 8 # divide by 8 instead of 8.5, deliberately leaving less safety margin
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes,
+ pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
+ while here > ref:
+ axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]
+
+ tmp = deepcopy(new_shp)
+ tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
+ _, _, _, _, shape_must_be_divisible_by_new = \
+ get_pool_and_conv_props(current_spacing, tmp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ )
+ new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]
+
+ # we have to recompute numpool now:
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ )
+
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes, pool_op_kernel_sizes,
+ conv_per_stage=self.conv_per_stage)
+ # print(new_shp)
+ input_patch_size = new_shp
+
+ batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D # This is what works with 128**3 patches
+ batch_size = int(np.floor(max(ref / here, 1) * batch_size))
+
+ # check if batch size is too large
+ max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
+ np.prod(input_patch_size, dtype=np.int64)).astype(int)
+ max_batch_size = max(max_batch_size, self.unet_min_batch_size)
+ batch_size = max(1, min(batch_size, max_batch_size))
+
+ do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[
+ 0]) > self.anisotropy_threshold
+
+ plan = {
+ 'batch_size': batch_size,
+ 'num_pool_per_axis': network_num_pool_per_axis,
+ 'patch_size': input_patch_size,
+ 'median_patient_size_in_voxels': new_median_shape,
+ 'current_spacing': current_spacing,
+ 'original_spacing': original_spacing,
+ 'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
+ 'pool_op_kernel_sizes': pool_op_kernel_sizes,
+ 'conv_kernel_sizes': conv_kernel_sizes,
+ }
+ return plan
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_3convperstage.py b/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_3convperstage.py
new file mode 100644
index 0000000000000000000000000000000000000000..44d57a1a1ff0905648b9c7ef964228eaa06f00f2
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v21_3convperstage.py
@@ -0,0 +1,40 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+
+import numpy as np
+from nnunet.experiment_planning.common_utils import get_pool_and_conv_props
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet_v21 import ExperimentPlanner3D_v21
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.paths import *
+
+
+class ExperimentPlanner3D_v21_3cps(ExperimentPlanner3D_v21):
+ """
+ uses 3x conv-instnorm-lrelu blocks per resolution instead of 2 while remaining within the same memory budget
+
+ This only works with 3d fullres because we use the same data as ExperimentPlanner3D_v21. Lowres would require
+ rerunning the preprocessing (different patch size = different 3d lowres target spacing)
+ """
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super(ExperimentPlanner3D_v21_3cps, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.plans_fname = join(self.preprocessed_output_folder,
+ "nnUNetPlansv2.1_3cps_plans_3D.pkl")
+ self.unet_base_num_features = 32
+ self.conv_per_stage = 3
+
+ def run_preprocessing(self, num_threads):
+ # the data is shared with ExperimentPlanner3D_v21 (see docstring), so no new preprocessing is needed
+ pass
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v22.py b/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v22.py
new file mode 100644
index 0000000000000000000000000000000000000000..63a702e5623d53d9cf7a5f995d8bbeb8d0f77fbd
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v22.py
@@ -0,0 +1,59 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
+ ExperimentPlanner3D_v21
+from nnunet.paths import *
+
+
+class ExperimentPlanner3D_v22(ExperimentPlanner3D_v21):
+ """
+ """
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super().__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.data_identifier = "nnUNetData_plans_v2.2"
+ self.plans_fname = join(self.preprocessed_output_folder,
+ "nnUNetPlansv2.2_plans_3D.pkl")
+
+ def get_target_spacing(self):
+ spacings = self.dataset_properties['all_spacings']
+ sizes = self.dataset_properties['all_sizes']
+
+ target = np.percentile(np.vstack(spacings), self.target_spacing_percentile, 0)
+ target_size = np.percentile(np.vstack(sizes), self.target_spacing_percentile, 0)
+ target_size_mm = np.array(target) * np.array(target_size)
+ # we need to identify datasets for which a different target spacing could be beneficial. These datasets have
+ # the following properties:
+ # - one axis with a much lower resolution than the others
+ # - the lowres axis has much less voxels than the others
+ # - (the size in mm of the lowres axis is also reduced)
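+ # worked example (illustrative numbers, assuming anisotropy_threshold = 3): target spacing
+ # (3.5, 0.8, 0.8) with median size (40, 512, 512) -> axis 0 has anisotropic spacing
+ # (3.5 > 3 * 0.8) and anisotropic voxel counts (40 * 3 < 512), so its target spacing is
+ # lowered towards the 10th percentile of the axis-0 spacings in the dataset (clamped below)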
+ worst_spacing_axis = np.argmax(target)
+ other_axes = [i for i in range(len(target)) if i != worst_spacing_axis]
+ other_spacings = [target[i] for i in other_axes]
+ other_sizes = [target_size[i] for i in other_axes]
+
+ has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * max(other_spacings))
+ has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes)
+ # we don't use the last one for now
+ #median_size_in_mm = target[target_size_mm] * RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD < max(target_size_mm)
+
+ if has_aniso_spacing and has_aniso_voxels:
+ spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis]
+ target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10)
+ # don't let the target spacing of that axis drop below self.anisotropy_threshold * the other axes' spacing
+ target_spacing_of_that_axis = max(max(other_spacings) * self.anisotropy_threshold, target_spacing_of_that_axis)
+ target[worst_spacing_axis] = target_spacing_of_that_axis
+ return target
+
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v23.py b/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v23.py
new file mode 100644
index 0000000000000000000000000000000000000000..5854e843f6e5161d2f88189504a44f9b673a37e8
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_baseline_3DUNet_v23.py
@@ -0,0 +1,28 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
+ ExperimentPlanner3D_v21
+from nnunet.paths import *
+
+
+class ExperimentPlanner3D_v23(ExperimentPlanner3D_v21):
+ """
+ """
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super(ExperimentPlanner3D_v23, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.data_identifier = "nnUNetData_plans_v2.3"
+ self.plans_fname = join(self.preprocessed_output_folder,
+ "nnUNetPlansv2.3_plans_3D.pkl")
+ self.preprocessor_name = "Preprocessor3DDifferentResampling"
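+
+# The only functional change relative to v2.1 is the preprocessor: "Preprocessor3DDifferentResampling" is looked
+# up by name in nnunet.preprocessing at preprocessing time (see that module for the actual resampling behaviour).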
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_pretrained.py b/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_pretrained.py
new file mode 100644
index 0000000000000000000000000000000000000000..0150607e673a35e220e4e42c6243a1b785f10ec5
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_pretrained.py
@@ -0,0 +1,42 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from batchgenerators.utilities.file_and_folder_operations import load_pickle
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet_v21 import ExperimentPlanner3D_v21
+from nnunet.paths import *
+
+
+class ExperimentPlanner3D_v21_Pretrained(ExperimentPlanner3D_v21):
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder, pretrained_model_plans_file: str,
+ pretrained_name: str):
+ super().__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.pretrained_model_plans_file = pretrained_model_plans_file
+ self.pretrained_name = pretrained_name
+ self.data_identifier = "nnUNetData_pretrained_" + pretrained_name
+ self.plans_fname = join(self.preprocessed_output_folder, "nnUNetPlans_pretrained_%s_plans_3D.pkl" % pretrained_name)
+
+ def load_pretrained_plans(self):
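+        # keep the number of classes of the current dataset, but take everything else (target spacing, patch
+        # size, architecture, preprocessor) from the pretrained model's plans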
+ classes = self.plans['num_classes']
+ self.plans = load_pickle(self.pretrained_model_plans_file)
+ self.plans['num_classes'] = classes
+ self.transpose_forward = self.plans['transpose_forward']
+ self.preprocessor_name = self.plans['preprocessor_name']
+ self.plans_per_stage = self.plans['plans_per_stage']
+ self.plans['data_identifier'] = self.data_identifier
+ self.save_my_plans()
+ print(self.plans['plans_per_stage'])
+
+ def run_preprocessing(self, num_threads):
+ self.load_pretrained_plans()
+ super().run_preprocessing(num_threads)
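+
+# Intended use (sketch, hypothetical paths and names): make the preprocessing of a new task compatible with an
+# existing pretrained model:
+#   planner = ExperimentPlanner3D_v21_Pretrained(cropped_dir, preprocessed_dir,
+#                                                '/path/to/pretrained/nnUNetPlansv2.1_plans_3D.pkl', 'myModel')
+#   planner.plan_experiment()     # creates the initial plans (including num_classes)
+#   planner.run_preprocessing(8)  # overrides the plans with the pretrained ones, then preprocesses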
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_residual_3DUNet_v21.py b/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_residual_3DUNet_v21.py
new file mode 100644
index 0000000000000000000000000000000000000000..efe01657141ce66e7eff0e61bd6a06ee0ccc35e5
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/experiment_planner_residual_3DUNet_v21.py
@@ -0,0 +1,132 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from copy import deepcopy
+
+import numpy as np
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
+ ExperimentPlanner3D_v21
+from nnunet.experiment_planning.common_utils import get_pool_and_conv_props
+from nnunet.paths import *
+from nnunet.network_architecture.generic_modular_residual_UNet import FabiansUNet
+
+
+class ExperimentPlanner3DFabiansResUNet_v21(ExperimentPlanner3D_v21):
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super(ExperimentPlanner3DFabiansResUNet_v21, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.data_identifier = "nnUNetData_plans_v2.1"# "nnUNetData_FabiansResUNet_v2.1"
+ self.plans_fname = join(self.preprocessed_output_folder,
+ "nnUNetPlans_FabiansResUNet_v2.1_plans_3D.pkl")
+
+ def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
+ num_modalities, num_classes):
+ """
+ We use FabiansUNet instead of Generic_UNet
+ """
+ new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
+ dataset_num_voxels = np.prod(new_median_shape) * num_cases
+
+        # the next (commented) line is what we used as the default before: the patch size had the same aspect
+        # ratio as the median shape of a patient. We swapped that for the spacing-based computation below.
+        # input_patch_size = new_median_shape
+
+        # compute how many voxels are one mm
+        input_patch_size = 1 / np.array(current_spacing)
+
+        # normalize voxels per mm
+        input_patch_size /= input_patch_size.mean()
+
+        # create a starting patch that is isotropic in mm; the coarsest axis gets 512 voxels
+        input_patch_size *= 1 / min(input_patch_size) * 512 # to get a starting value
+        input_patch_size = np.round(input_patch_size).astype(int)
+
+        # clip it to the median shape of the dataset because patches larger than that make little sense
+        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]
+
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool)
+ pool_op_kernel_sizes = [[1, 1, 1]] + pool_op_kernel_sizes
+ blocks_per_stage_encoder = FabiansUNet.default_blocks_per_stage_encoder[:len(pool_op_kernel_sizes)]
+ blocks_per_stage_decoder = FabiansUNet.default_blocks_per_stage_decoder[:len(pool_op_kernel_sizes) - 1]
+
+ ref = FabiansUNet.use_this_for_3D_configuration
+ here = FabiansUNet.compute_approx_vram_consumption(input_patch_size, self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities, num_classes,
+ pool_op_kernel_sizes, blocks_per_stage_encoder,
+ blocks_per_stage_decoder, 2, self.unet_min_batch_size,)
+ while here > ref:
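+            # reduce the axis that is currently largest relative to the median shape of the dataset and repeat
+            # until the estimated VRAM consumption fits the reference budget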
+ axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]
+
+ tmp = deepcopy(new_shp)
+ tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
+ _, _, _, _, shape_must_be_divisible_by_new = \
+ get_pool_and_conv_props(current_spacing, tmp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ )
+ new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]
+
+ # we have to recompute numpool now:
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ )
+ pool_op_kernel_sizes = [[1, 1, 1]] + pool_op_kernel_sizes
+ blocks_per_stage_encoder = FabiansUNet.default_blocks_per_stage_encoder[:len(pool_op_kernel_sizes)]
+ blocks_per_stage_decoder = FabiansUNet.default_blocks_per_stage_decoder[:len(pool_op_kernel_sizes) - 1]
+ here = FabiansUNet.compute_approx_vram_consumption(new_shp, self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities, num_classes,
+ pool_op_kernel_sizes, blocks_per_stage_encoder,
+ blocks_per_stage_decoder, 2, self.unet_min_batch_size)
+ input_patch_size = new_shp
+
+ batch_size = FabiansUNet.default_min_batch_size
+ batch_size = int(np.floor(max(ref / here, 1) * batch_size))
+
+ # check if batch size is too large
+ max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
+ np.prod(input_patch_size, dtype=np.int64)).astype(int)
+ max_batch_size = max(max_batch_size, self.unet_min_batch_size)
+ batch_size = max(1, min(batch_size, max_batch_size))
+
+        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[0]) > self.anisotropy_threshold
+
+ plan = {
+ 'batch_size': batch_size,
+ 'num_pool_per_axis': network_num_pool_per_axis,
+ 'patch_size': input_patch_size,
+ 'median_patient_size_in_voxels': new_median_shape,
+ 'current_spacing': current_spacing,
+ 'original_spacing': original_spacing,
+ 'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
+ 'pool_op_kernel_sizes': pool_op_kernel_sizes,
+ 'conv_kernel_sizes': conv_kernel_sizes,
+ 'num_blocks_encoder': blocks_per_stage_encoder,
+ 'num_blocks_decoder': blocks_per_stage_decoder
+ }
+ return plan
+
+ def run_preprocessing(self, num_threads):
+ """
+ On all datasets except 3d fullres on spleen the preprocessed data would look identical to
+ ExperimentPlanner3D_v21 (I tested decathlon data only). Therefore we just reuse the preprocessed data of
+ that other planner
+ :param num_threads:
+ :return:
+ """
+ pass
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/normalization/__init__.py b/nnunet/experiment_planning/alternative_experiment_planning/normalization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_2DUNet_v21_RGB_scaleto_0_1.py b/nnunet/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_2DUNet_v21_RGB_scaleto_0_1.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3c177fc688d31381572b0bcac2e067c988a3e19
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_2DUNet_v21_RGB_scaleto_0_1.py
@@ -0,0 +1,32 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.experiment_planning.experiment_planner_baseline_2DUNet_v21 import ExperimentPlanner2D_v21
+from nnunet.paths import *
+
+
+class ExperimentPlanner2D_v21_RGB_scaleTo_0_1(ExperimentPlanner2D_v21):
+ """
+    used by the tutorial in nnunet.tutorials.custom_preprocessing
+ """
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super().__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.data_identifier = "nnUNet_RGB_scaleTo_0_1"
+ self.plans_fname = join(self.preprocessed_output_folder, "nnUNet_RGB_scaleTo_0_1" + "_plans_2D.pkl")
+
+ # The custom preprocessor class we intend to use is GenericPreprocessor_scale_uint8_to_0_1. It must be located
+        # in nnunet.preprocessing (any file or submodule) and will be found by its name. Make sure to always define
+ # unique names!
+ self.preprocessor_name = 'GenericPreprocessor_scale_uint8_to_0_1'
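+
+# Note: the preprocessor is looked up by name at preprocessing time. GenericPreprocessor_scale_uint8_to_0_1 is
+# expected to map uint8 intensities from [0, 255] to [0, 1]; see nnunet.preprocessing for its actual definition.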
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_3DUNet_CT2.py b/nnunet/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_3DUNet_CT2.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c24340dafdadafcb40993c6286e26d9e5be1f6e
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_3DUNet_CT2.py
@@ -0,0 +1,45 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from collections import OrderedDict
+
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
+from nnunet.paths import *
+
+
+class ExperimentPlannerCT2(ExperimentPlanner):
+ """
+ preprocesses CT data with the "CT2" normalization.
+
+    (the clip range comes from the training set and is the 0.5 and 99.5 percentile of the foreground intensities)
+    CT = clip to range, then normalize with the global mean and std (computed on the foreground of the training set)
+    CT2 = clip to range, then normalize each case separately with its own mean and std (computed within the area that was inside the clip range)
+ """
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super(ExperimentPlannerCT2, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.data_identifier = "nnUNet_CT2"
+ self.plans_fname = join(self.preprocessed_output_folder, "nnUNetPlans" + "CT2_plans_3D.pkl")
+
+ def determine_normalization_scheme(self):
+ schemes = OrderedDict()
+ modalities = self.dataset_properties['modalities']
+ num_modalities = len(list(modalities.keys()))
+
+ for i in range(num_modalities):
+ if modalities[i] == "CT":
+ schemes[i] = "CT2"
+ else:
+ schemes[i] = "nonCT"
+ return schemes
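+
+# Example: for a dataset with modalities {0: "CT", 1: "MR"} this returns OrderedDict([(0, "CT2"), (1, "nonCT")])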
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_3DUNet_nonCT.py b/nnunet/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_3DUNet_nonCT.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7ca419eca6913160f7876971ad5220fcfba8299
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/normalization/experiment_planner_3DUNet_nonCT.py
@@ -0,0 +1,43 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from collections import OrderedDict
+
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
+from nnunet.paths import *
+
+
+class ExperimentPlannernonCT(ExperimentPlanner):
+ """
+ Preprocesses all data in nonCT mode (this is what we use for MRI per default, but here it is applied to CT images
+ as well)
+ """
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super(ExperimentPlannernonCT, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.data_identifier = "nnUNet_nonCT"
+ self.plans_fname = join(self.preprocessed_output_folder, "nnUNetPlans" + "nonCT_plans_3D.pkl")
+
+ def determine_normalization_scheme(self):
+ schemes = OrderedDict()
+ modalities = self.dataset_properties['modalities']
+ num_modalities = len(list(modalities.keys()))
+
+        # this planner deliberately applies the nonCT scheme to every modality, including CT
+        for i in range(num_modalities):
+            schemes[i] = "nonCT"
+ return schemes
+
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/patch_size/__init__.py b/nnunet/experiment_planning/alternative_experiment_planning/patch_size/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/patch_size/experiment_planner_3DUNet_isotropic_in_mm.py b/nnunet/experiment_planning/alternative_experiment_planning/patch_size/experiment_planner_3DUNet_isotropic_in_mm.py
new file mode 100644
index 0000000000000000000000000000000000000000..1908ae8d465df909752bd96b02ec3da1e8347c8f
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/patch_size/experiment_planner_3DUNet_isotropic_in_mm.py
@@ -0,0 +1,128 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from copy import deepcopy
+
+import numpy as np
+from nnunet.experiment_planning.common_utils import get_pool_and_conv_props_poolLateV2
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.paths import *
+
+
+class ExperimentPlannerIso(ExperimentPlanner):
+ """
+ attempts to create patches that have an isotropic size (in mm, not voxels)
+
+ CAREFUL!
+ this one does not support transpose_forward and transpose_backward
+ """
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super().__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.plans_fname = join(self.preprocessed_output_folder, "nnUNetPlans" + "fixedisoPatchesInmm_plans_3D.pkl")
+ self.data_identifier = "nnUNet_isoPatchesInmm"
+
+ def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
+ num_modalities, num_classes):
+ """
+ """
+ new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
+ dataset_num_voxels = np.prod(new_median_shape) * num_cases
+
+        # the next (commented) line is what we used as the default before: the patch size had the same aspect
+        # ratio as the median shape of a patient. We swapped that for the spacing-based computation below.
+        # input_patch_size = new_median_shape
+
+        # compute how many voxels are one mm
+        input_patch_size = 1 / np.array(current_spacing)
+
+        # normalize voxels per mm
+        input_patch_size /= input_patch_size.mean()
+
+        # create a starting patch that is isotropic in mm; the coarsest axis gets 512 voxels
+        input_patch_size *= 1 / min(input_patch_size) * 512 # to get a starting value
+        input_patch_size = np.round(input_patch_size).astype(int)
+
+        # clip it to the median shape of the dataset because patches larger than that make little sense
+        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]
+
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(input_patch_size,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ current_spacing)
+
+ ref = Generic_UNet.use_this_for_batch_size_computation_3D
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes,
+ pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
+ while here > ref:
+ # here is the difference to ExperimentPlanner. In the old version we made the aspect ratio match
+ # between patch and new_median_shape, regardless of spacing. It could be better to enforce isotropy
+ # (in mm) instead
+ current_patch_in_mm = new_shp * current_spacing
+ axis_to_be_reduced = np.argsort(current_patch_in_mm)[-1]
+
+ # from here on it's the same as before
+ tmp = deepcopy(new_shp)
+ tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
+ _, _, _, _, shape_must_be_divisible_by_new = \
+ get_pool_and_conv_props_poolLateV2(tmp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ current_spacing)
+ new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]
+
+ # we have to recompute numpool now:
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(new_shp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ current_spacing)
+
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes, pool_op_kernel_sizes,
+ conv_per_stage=self.conv_per_stage)
+ print(new_shp)
+
+ input_patch_size = new_shp
+
+ batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D # This is what works with 128**3
+ batch_size = int(np.floor(max(ref / here, 1) * batch_size))
+
+ # check if batch size is too large
+ max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
+ np.prod(input_patch_size, dtype=np.int64)).astype(int)
+ max_batch_size = max(max_batch_size, self.unet_min_batch_size)
+ batch_size = max(1, min(batch_size, max_batch_size))
+
+        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[0]) > self.anisotropy_threshold
+
+ plan = {
+ 'batch_size': batch_size,
+ 'num_pool_per_axis': network_num_pool_per_axis,
+ 'patch_size': input_patch_size,
+ 'median_patient_size_in_voxels': new_median_shape,
+ 'current_spacing': current_spacing,
+ 'original_spacing': original_spacing,
+ 'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
+ 'pool_op_kernel_sizes': pool_op_kernel_sizes,
+ 'conv_kernel_sizes': conv_kernel_sizes,
+ }
+ return plan
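+
+# Worked example (hypothetical numbers): for current_spacing = (3.0, 0.7, 0.7) the starting patch computed above
+# is roughly (512, 2194, 2194) voxels, i.e. about 1536 mm along each axis, before it is clipped to the median
+# shape and shrunk until it fits the VRAM budget.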
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/patch_size/experiment_planner_3DUNet_isotropic_in_voxels.py b/nnunet/experiment_planning/alternative_experiment_planning/patch_size/experiment_planner_3DUNet_isotropic_in_voxels.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa041ab907021555ec6fda148f4d52a61bd7d9b7
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/patch_size/experiment_planner_3DUNet_isotropic_in_voxels.py
@@ -0,0 +1,115 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from copy import deepcopy
+
+import numpy as np
+from nnunet.experiment_planning.common_utils import get_pool_and_conv_props_poolLateV2
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.paths import *
+
+
+class ExperimentPlanner3D_IsoPatchesInVoxels(ExperimentPlanner):
+ """
+    patches that are isotropic in the number of voxels (not mm), such as 128x128x128, allow more voxels to be
+    processed at once because we do not need axis-specific pooling operations
+
+ CAREFUL!
+ this one does not support transpose_forward and transpose_backward
+ """
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super(ExperimentPlanner3D_IsoPatchesInVoxels, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.data_identifier = "nnUNetData_isoPatchesInVoxels"
+ self.plans_fname = join(self.preprocessed_output_folder, "nnUNetPlans" + "fixedisoPatchesInVoxels_plans_3D.pkl")
+
+ def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
+ num_modalities, num_classes):
+ """
+ """
+ new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
+ dataset_num_voxels = np.prod(new_median_shape) * num_cases
+
+ input_patch_size = new_median_shape
+
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(input_patch_size,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ current_spacing)
+
+ ref = Generic_UNet.use_this_for_batch_size_computation_3D
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes,
+ pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
+ while here > ref:
+ # find the largest axis. If patch is isotropic, pick the axis with the largest spacing
+ if len(np.unique(new_shp)) == 1:
+ axis_to_be_reduced = np.argsort(current_spacing)[-1]
+ else:
+ axis_to_be_reduced = np.argsort(new_shp)[-1]
+
+ tmp = deepcopy(new_shp)
+ tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
+ _, _, _, _, shape_must_be_divisible_by_new = \
+ get_pool_and_conv_props_poolLateV2(tmp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ current_spacing)
+ new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]
+
+ # we have to recompute numpool now:
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(new_shp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ current_spacing)
+
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes, pool_op_kernel_sizes,
+ conv_per_stage=self.conv_per_stage)
+ print(new_shp)
+
+ input_patch_size = new_shp
+
+ batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D # This is what works with 128**3
+ batch_size = int(np.floor(max(ref / here, 1) * batch_size))
+
+ # check if batch size is too large
+ max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
+ np.prod(input_patch_size, dtype=np.int64)).astype(int)
+ max_batch_size = max(max_batch_size, self.unet_min_batch_size)
+ batch_size = max(1, min(batch_size, max_batch_size))
+
+        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[0]) > self.anisotropy_threshold
+
+ plan = {
+ 'batch_size': batch_size,
+ 'num_pool_per_axis': network_num_pool_per_axis,
+ 'patch_size': input_patch_size,
+ 'median_patient_size_in_voxels': new_median_shape,
+ 'current_spacing': current_spacing,
+ 'original_spacing': original_spacing,
+ 'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
+ 'pool_op_kernel_sizes': pool_op_kernel_sizes,
+ 'conv_kernel_sizes': conv_kernel_sizes,
+ }
+ return plan
+
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/pooling_and_convs/__init__.py b/nnunet/experiment_planning/alternative_experiment_planning/pooling_and_convs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/pooling_and_convs/experiment_planner_baseline_3DUNet_allConv3x3.py b/nnunet/experiment_planning/alternative_experiment_planning/pooling_and_convs/experiment_planner_baseline_3DUNet_allConv3x3.py
new file mode 100644
index 0000000000000000000000000000000000000000..58386073a8cf49f1ade33fad15b72d84d7867e9c
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/pooling_and_convs/experiment_planner_baseline_3DUNet_allConv3x3.py
@@ -0,0 +1,139 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+
+import numpy as np
+from nnunet.experiment_planning.common_utils import get_pool_and_conv_props_poolLateV2
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.paths import *
+
+
+class ExperimentPlannerAllConv3x3(ExperimentPlanner):
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super(ExperimentPlannerAllConv3x3, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.plans_fname = join(self.preprocessed_output_folder,
+ "nnUNetPlans" + "allConv3x3_plans_3D.pkl")
+
+ def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
+ num_modalities, num_classes):
+ """
+        Computation of the input patch size starts out with the new median shape (in voxels) of the dataset. This
+        is opposed to prior experiments where I based it on the median size in mm. The rationale behind this is
+        that for some organ of interest the acquisition method will most likely be chosen such that field of view
+        and voxel resolution go hand in hand to show the doctor what they need to see. This assumption may be
+        violated for some modalities with anisotropy (cine MRI) but we will have to live with that. In future
+        experiments I will try to 1) make the input patch size match the aspect ratio of the input size in mm
+        (instead of voxels) and 2) enforce that we see the same 'distance' in all directions (try to maintain an
+        equal patch size in mm)
+
+        The patches created here attempt to keep the aspect ratio of the new_median_shape
+
+ :param current_spacing:
+ :param original_spacing:
+ :param original_shape:
+ :param num_cases:
+ :return:
+ """
+ new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
+ dataset_num_voxels = np.prod(new_median_shape) * num_cases
+
+        # the next (commented) line is what we used as the default before: the patch size had the same aspect
+        # ratio as the median shape of a patient. We swapped that for the spacing-based computation below.
+        # input_patch_size = new_median_shape
+
+        # compute how many voxels are one mm
+        input_patch_size = 1 / np.array(current_spacing)
+
+        # normalize voxels per mm
+        input_patch_size /= input_patch_size.mean()
+
+        # create a starting patch that is isotropic in mm; the coarsest axis gets 512 voxels
+        input_patch_size *= 1 / min(input_patch_size) * 512 # to get a starting value
+        input_patch_size = np.round(input_patch_size).astype(int)
+
+        # clip it to the median shape of the dataset because patches larger than that make little sense
+        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]
+
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(input_patch_size,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ current_spacing)
+
+ ref = Generic_UNet.use_this_for_batch_size_computation_3D
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes,
+ pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
+ while here > ref:
+ axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]
+
+ tmp = deepcopy(new_shp)
+ tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
+ _, _, _, _, shape_must_be_divisible_by_new = \
+ get_pool_and_conv_props_poolLateV2(tmp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ current_spacing)
+ new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]
+
+ # we have to recompute numpool now:
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(new_shp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ current_spacing)
+
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes, pool_op_kernel_sizes,
+ conv_per_stage=self.conv_per_stage)
+ print(new_shp)
+
+ input_patch_size = new_shp
+
+ batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D # This is what works with 128**3
+ batch_size = int(np.floor(max(ref / here, 1) * batch_size))
+
+ # check if batch size is too large
+ max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
+ np.prod(input_patch_size, dtype=np.int64)).astype(int)
+ max_batch_size = max(max_batch_size, self.unet_min_batch_size)
+ batch_size = max(1, min(batch_size, max_batch_size))
+
+        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[0]) > self.anisotropy_threshold
+
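+        # this is the actual difference to the base planner: force every convolution kernel to 3x3x3, including
+        # stages where the base planner would use an anisotropic kernel such as 1x3x3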
+ for s in range(len(conv_kernel_sizes)):
+ conv_kernel_sizes[s] = [3 for _ in conv_kernel_sizes[s]]
+
+ plan = {
+ 'batch_size': batch_size,
+ 'num_pool_per_axis': network_num_pool_per_axis,
+ 'patch_size': input_patch_size,
+ 'median_patient_size_in_voxels': new_median_shape,
+ 'current_spacing': current_spacing,
+ 'original_spacing': original_spacing,
+ 'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
+ 'pool_op_kernel_sizes': pool_op_kernel_sizes,
+ 'conv_kernel_sizes': conv_kernel_sizes,
+ }
+ return plan
+
+ def run_preprocessing(self, num_threads):
+ pass
+
+
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/pooling_and_convs/experiment_planner_baseline_3DUNet_poolBasedOnSpacing.py b/nnunet/experiment_planning/alternative_experiment_planning/pooling_and_convs/experiment_planner_baseline_3DUNet_poolBasedOnSpacing.py
new file mode 100644
index 0000000000000000000000000000000000000000..f99c137dd38f7766a212a656e0e1dcc1d8d33612
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/pooling_and_convs/experiment_planner_baseline_3DUNet_poolBasedOnSpacing.py
@@ -0,0 +1,124 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+
+import numpy as np
+from nnunet.experiment_planning.common_utils import get_pool_and_conv_props
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.paths import *
+
+
+class ExperimentPlannerPoolBasedOnSpacing(ExperimentPlanner):
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super(ExperimentPlannerPoolBasedOnSpacing, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.data_identifier = "nnUNetData_poolBasedOnSpacing"
+ self.plans_fname = join(self.preprocessed_output_folder,
+ "nnUNetPlans" + "poolBasedOnSpacing_plans_3D.pkl")
+
+ def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
+ num_modalities, num_classes):
+ """
+        ExperimentPlanner configures pooling so that we pool late: if the number of pooling operations per axis
+        is (2, 3, 3), the first pooling operation will always pool axes 1 and 2 and not axis 0, irrespective of
+        spacing. This can cause a larger memory footprint, so it can be beneficial to revise it.
+
+        Here we pool based on the spacing of the data instead.
+
+ """
+ new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
+ dataset_num_voxels = np.prod(new_median_shape) * num_cases
+
+        # the next (commented) line is what we used as the default before: the patch size had the same aspect
+        # ratio as the median shape of a patient. We swapped that for the spacing-based computation below.
+        # input_patch_size = new_median_shape
+
+        # compute how many voxels are one mm
+        input_patch_size = 1 / np.array(current_spacing)
+
+        # normalize voxels per mm
+        input_patch_size /= input_patch_size.mean()
+
+        # create a starting patch that is isotropic in mm; the coarsest axis gets 512 voxels
+        input_patch_size *= 1 / min(input_patch_size) * 512 # to get a starting value
+        input_patch_size = np.round(input_patch_size).astype(int)
+
+        # clip it to the median shape of the dataset because patches larger than that make little sense
+        input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]
+
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool)
+
+ ref = Generic_UNet.use_this_for_batch_size_computation_3D
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes,
+ pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
+ while here > ref:
+ axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]
+
+ tmp = deepcopy(new_shp)
+ tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
+ _, _, _, _, shape_must_be_divisible_by_new = \
+ get_pool_and_conv_props(current_spacing, tmp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ )
+ new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]
+
+ # we have to recompute numpool now:
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ )
+
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes, pool_op_kernel_sizes,
+ conv_per_stage=self.conv_per_stage)
+ print(new_shp)
+
+ input_patch_size = new_shp
+
+        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D # This is what works with 128**3
+ batch_size = int(np.floor(max(ref / here, 1) * batch_size))
+
+ # check if batch size is too large
+ max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
+ np.prod(input_patch_size, dtype=np.int64)).astype(int)
+ max_batch_size = max(max_batch_size, self.unet_min_batch_size)
+ batch_size = max(1, min(batch_size, max_batch_size))
+
+        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[0]) > self.anisotropy_threshold
+
+ plan = {
+ 'batch_size': batch_size,
+ 'num_pool_per_axis': network_num_pool_per_axis,
+ 'patch_size': input_patch_size,
+ 'median_patient_size_in_voxels': new_median_shape,
+ 'current_spacing': current_spacing,
+ 'original_spacing': original_spacing,
+ 'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
+ 'pool_op_kernel_sizes': pool_op_kernel_sizes,
+ 'conv_kernel_sizes': conv_kernel_sizes,
+ }
+ return plan
+
+
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/readme.md b/nnunet/experiment_planning/alternative_experiment_planning/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..b9c5c0e2a735727b538ffb0b4c7ac58f29b0c34e
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/readme.md
@@ -0,0 +1,2 @@
+These alternatives are not used in nnU-Net, but you can use them if you believe they might be better suited for you.
+I (Fabian) have not found them to be consistently superior.
\ No newline at end of file
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/target_spacing/__init__.py b/nnunet/experiment_planning/alternative_experiment_planning/target_spacing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_targetSpacingForAnisoAxis.py b/nnunet/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_targetSpacingForAnisoAxis.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9ca27aa0869df6dcc4df1f3aea916e7d240a921
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_targetSpacingForAnisoAxis.py
@@ -0,0 +1,63 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
+from nnunet.paths import *
+
+
+class ExperimentPlannerTargetSpacingForAnisoAxis(ExperimentPlanner):
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super().__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.data_identifier = "nnUNetData_targetSpacingForAnisoAxis"
+ self.plans_fname = join(self.preprocessed_output_folder,
+ "nnUNetPlans" + "targetSpacingForAnisoAxis_plans_3D.pkl")
+
+ def get_target_spacing(self):
+ """
+        Per default we use the 50th percentile (the median) for the target spacing. Higher spacing results in
+        smaller data and thus faster and easier training; smaller spacing results in larger data and thus longer
+        and harder training.
+
+        For some datasets the median is not a good choice. Those are the datasets where the spacing is very
+        anisotropic (for example ACDC with (10, 1.5, 1.5)). These datasets still have examples with a spacing of
+        5 or 6 mm in the low resolution axis. Choosing the median here will result in bad interpolation artifacts
+        that can substantially impact performance (due to the low number of slices).
+ """
+ spacings = self.dataset_properties['all_spacings']
+ sizes = self.dataset_properties['all_sizes']
+
+ target = np.percentile(np.vstack(spacings), self.target_spacing_percentile, 0)
+ target_size = np.percentile(np.vstack(sizes), self.target_spacing_percentile, 0)
+ target_size_mm = np.array(target) * np.array(target_size)
+ # we need to identify datasets for which a different target spacing could be beneficial. These datasets have
+ # the following properties:
+        # - one axis with much lower resolution than the others
+        # - the lowres axis has far fewer voxels than the others
+ # - (the size in mm of the lowres axis is also reduced)
+ worst_spacing_axis = np.argmax(target)
+ other_axes = [i for i in range(len(target)) if i != worst_spacing_axis]
+ other_spacings = [target[i] for i in other_axes]
+ other_sizes = [target_size[i] for i in other_axes]
+
+ has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * max(other_spacings))
+ has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < max(other_sizes)
+ # we don't use the last one for now
+ #median_size_in_mm = target[target_size_mm] * RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD < max(target_size_mm)
+
+ if has_aniso_spacing and has_aniso_voxels:
+ spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis]
+ target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10)
+ target[worst_spacing_axis] = target_spacing_of_that_axis
+ return target
+
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_v21_customTargetSpacing_2x2x2.py b/nnunet/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_v21_customTargetSpacing_2x2x2.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7f8f7dff9943031a5888585b37f4f83bd9783fc
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_v21_customTargetSpacing_2x2x2.py
@@ -0,0 +1,33 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet_v21 import ExperimentPlanner3D_v21
+from nnunet.paths import *
+
+
+class ExperimentPlanner3D_v21_customTargetSpacing_2x2x2(ExperimentPlanner3D_v21):
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+        super().__init__(folder_with_cropped_data, preprocessed_output_folder)
+ # we change the data identifier and plans_fname. This will make this experiment planner save the preprocessed
+ # data in a different folder so that they can co-exist with the default (ExperimentPlanner3D_v21). We also
+ # create a custom plans file that will be linked to this data
+ self.data_identifier = "nnUNetData_plans_v2.1_trgSp_2x2x2"
+ self.plans_fname = join(self.preprocessed_output_folder,
+ "nnUNetPlansv2.1_trgSp_2x2x2_plans_3D.pkl")
+
+ def get_target_spacing(self):
+ # simply return the desired spacing as np.array
+ return np.array([2., 2., 2.]) # make sure this is float!!!! Not int!
+
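+# To run this planner (sketch; assumes the nnU-Net v1 command line interface):
+#   nnUNet_plan_and_preprocess -t TASK_ID -pl3d ExperimentPlanner3D_v21_customTargetSpacing_2x2x2 -pl2d None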
diff --git a/nnunet/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_v21_noResampling.py b/nnunet/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_v21_noResampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..3810cc780762adb4c2608f455c2550c4a85260a5
--- /dev/null
+++ b/nnunet/experiment_planning/alternative_experiment_planning/target_spacing/experiment_planner_baseline_3DUNet_v21_noResampling.py
@@ -0,0 +1,216 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from nnunet.experiment_planning.alternative_experiment_planning.experiment_planner_baseline_3DUNet_v21_16GB import \
+ ExperimentPlanner3D_v21_16GB
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet_v21 import \
+ ExperimentPlanner3D_v21
+from nnunet.paths import *
+
+
+class ExperimentPlanner3D_v21_noResampling(ExperimentPlanner3D_v21):
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super(ExperimentPlanner3D_v21_noResampling, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.data_identifier = "nnUNetData_noRes_plans_v2.1"
+ self.plans_fname = join(self.preprocessed_output_folder,
+ "nnUNetPlansv2.1_noRes_plans_3D.pkl")
+ self.preprocessor_name = "PreprocessorFor3D_NoResampling"
+
+ def plan_experiment(self):
+ """
+ DIFFERENCE TO ExperimentPlanner3D_v21: no 3d lowres
+ :return:
+ """
+ use_nonzero_mask_for_normalization = self.determine_whether_to_use_mask_for_norm()
+ print("Are we using the nonzero mask for normalization?", use_nonzero_mask_for_normalization)
+ spacings = self.dataset_properties['all_spacings']
+ sizes = self.dataset_properties['all_sizes']
+
+ all_classes = self.dataset_properties['all_classes']
+ modalities = self.dataset_properties['modalities']
+ num_modalities = len(list(modalities.keys()))
+
+ target_spacing = self.get_target_spacing()
+ new_shapes = [np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)]
+
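+        # move the axis with the largest target spacing to the front; transpose_backward restores the original
+        # axis order later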
+ max_spacing_axis = np.argmax(target_spacing)
+ remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis]
+ self.transpose_forward = [max_spacing_axis] + remaining_axes
+ self.transpose_backward = [np.argwhere(np.array(self.transpose_forward) == i)[0][0] for i in range(3)]
+
+ # we base our calculations on the median shape of the datasets
+ median_shape = np.median(np.vstack(new_shapes), 0)
+ print("the median shape of the dataset is ", median_shape)
+
+ max_shape = np.max(np.vstack(new_shapes), 0)
+ print("the max shape in the dataset is ", max_shape)
+ min_shape = np.min(np.vstack(new_shapes), 0)
+ print("the min shape in the dataset is ", min_shape)
+
+ print("we don't want feature maps smaller than ", self.unet_featuremap_min_edge_length, " in the bottleneck")
+
+ # how many stages will the image pyramid have?
+ self.plans_per_stage = list()
+
+ target_spacing_transposed = np.array(target_spacing)[self.transpose_forward]
+ median_shape_transposed = np.array(median_shape)[self.transpose_forward]
+ print("the transposed median shape of the dataset is ", median_shape_transposed)
+
+ print("generating configuration for 3d_fullres")
+ self.plans_per_stage.append(self.get_properties_for_stage(target_spacing_transposed, target_spacing_transposed,
+ median_shape_transposed,
+ len(self.list_of_cropped_npz_files),
+ num_modalities, len(all_classes) + 1))
+
+        # unlike ExperimentPlanner3D_v21 we never add a 3d_lowres stage here: this planner does not resample,
+        # so the usual check of the median patient size against the patch size is omitted
+
+ self.plans_per_stage = self.plans_per_stage[::-1]
+ self.plans_per_stage = {i: self.plans_per_stage[i] for i in range(len(self.plans_per_stage))} # convert to dict
+
+ print(self.plans_per_stage)
+ print("transpose forward", self.transpose_forward)
+ print("transpose backward", self.transpose_backward)
+
+ normalization_schemes = self.determine_normalization_scheme()
+ only_keep_largest_connected_component, min_size_per_class, min_region_size_per_class = None, None, None
+ # removed training data based postprocessing. This is deprecated
+
+ # these are independent of the stage
+ plans = {'num_stages': len(list(self.plans_per_stage.keys())), 'num_modalities': num_modalities,
+ 'modalities': modalities, 'normalization_schemes': normalization_schemes,
+ 'dataset_properties': self.dataset_properties, 'list_of_npz_files': self.list_of_cropped_npz_files,
+ 'original_spacings': spacings, 'original_sizes': sizes,
+ 'preprocessed_data_folder': self.preprocessed_output_folder, 'num_classes': len(all_classes),
+ 'all_classes': all_classes, 'base_num_features': self.unet_base_num_features,
+ 'use_mask_for_norm': use_nonzero_mask_for_normalization,
+ 'keep_only_largest_region': only_keep_largest_connected_component,
+ 'min_region_size_per_class': min_region_size_per_class, 'min_size_per_class': min_size_per_class,
+ 'transpose_forward': self.transpose_forward, 'transpose_backward': self.transpose_backward,
+ 'data_identifier': self.data_identifier, 'plans_per_stage': self.plans_per_stage,
+ 'preprocessor_name': self.preprocessor_name,
+ 'conv_per_stage': self.conv_per_stage,
+ }
+
+ self.plans = plans
+ self.save_my_plans()
+
+
+class ExperimentPlanner3D_v21_noResampling_16GB(ExperimentPlanner3D_v21_16GB):
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super(ExperimentPlanner3D_v21_noResampling_16GB, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.data_identifier = "nnUNetData_noRes_plans_16GB_v2.1"
+ self.plans_fname = join(self.preprocessed_output_folder,
+ "nnUNetPlansv2.1_noRes_16GB_plans_3D.pkl")
+ self.preprocessor_name = "PreprocessorFor3D_NoResampling"
+
+ def plan_experiment(self):
+ """
+ DIFFERENCE TO ExperimentPlanner3D_v21: no 3d lowres
+ :return:
+ """
+ use_nonzero_mask_for_normalization = self.determine_whether_to_use_mask_for_norm()
+ print("Are we using the nonzero mask for normalization?", use_nonzero_mask_for_normalization)
+ spacings = self.dataset_properties['all_spacings']
+ sizes = self.dataset_properties['all_sizes']
+
+ all_classes = self.dataset_properties['all_classes']
+ modalities = self.dataset_properties['modalities']
+ num_modalities = len(list(modalities.keys()))
+
+ target_spacing = self.get_target_spacing()
+ new_shapes = [np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)]
+
+ max_spacing_axis = np.argmax(target_spacing)
+ remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis]
+ self.transpose_forward = [max_spacing_axis] + remaining_axes
+ self.transpose_backward = [np.argwhere(np.array(self.transpose_forward) == i)[0][0] for i in range(3)]
+
+ # we base our calculations on the median shape of the datasets
+ median_shape = np.median(np.vstack(new_shapes), 0)
+ print("the median shape of the dataset is ", median_shape)
+
+ max_shape = np.max(np.vstack(new_shapes), 0)
+ print("the max shape in the dataset is ", max_shape)
+ min_shape = np.min(np.vstack(new_shapes), 0)
+ print("the min shape in the dataset is ", min_shape)
+
+ print("we don't want feature maps smaller than ", self.unet_featuremap_min_edge_length, " in the bottleneck")
+
+ # how many stages will the image pyramid have?
+ self.plans_per_stage = list()
+
+ target_spacing_transposed = np.array(target_spacing)[self.transpose_forward]
+ median_shape_transposed = np.array(median_shape)[self.transpose_forward]
+ print("the transposed median shape of the dataset is ", median_shape_transposed)
+
+ print("generating configuration for 3d_fullres")
+ self.plans_per_stage.append(self.get_properties_for_stage(target_spacing_transposed, target_spacing_transposed,
+ median_shape_transposed,
+ len(self.list_of_cropped_npz_files),
+ num_modalities, len(all_classes) + 1))
+
+ # thanks Zakiyi (https://github.com/MIC-DKFZ/nnUNet/issues/61) for spotting this bug :-)
+ # if np.prod(self.plans_per_stage[-1]['median_patient_size_in_voxels'], dtype=np.int64) / \
+ # architecture_input_voxels < HOW_MUCH_OF_A_PATIENT_MUST_THE_NETWORK_SEE_AT_STAGE0:
+ architecture_input_voxels_here = np.prod(self.plans_per_stage[-1]['patch_size'], dtype=np.int64)
+ if np.prod(self.plans_per_stage[-1]['median_patient_size_in_voxels'], dtype=np.int64) / \
+ architecture_input_voxels_here < self.how_much_of_a_patient_must_the_network_see_at_stage0:
+ more = False
+ else:
+ more = True
+
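+        # unlike ExperimentPlanner3D_v21, this planner generates no 3d_lowres stage; 'more' is computed only
+        # to mirror the parent class and the branch below is intentionally a no-op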
+ if more:
+ pass
+
+ self.plans_per_stage = self.plans_per_stage[::-1]
+ self.plans_per_stage = {i: self.plans_per_stage[i] for i in range(len(self.plans_per_stage))} # convert to dict
+
+ print(self.plans_per_stage)
+ print("transpose forward", self.transpose_forward)
+ print("transpose backward", self.transpose_backward)
+
+ normalization_schemes = self.determine_normalization_scheme()
+ only_keep_largest_connected_component, min_size_per_class, min_region_size_per_class = None, None, None
+ # removed training data based postprocessing. This is deprecated
+
+ # these are independent of the stage
+ plans = {'num_stages': len(list(self.plans_per_stage.keys())), 'num_modalities': num_modalities,
+ 'modalities': modalities, 'normalization_schemes': normalization_schemes,
+ 'dataset_properties': self.dataset_properties, 'list_of_npz_files': self.list_of_cropped_npz_files,
+ 'original_spacings': spacings, 'original_sizes': sizes,
+ 'preprocessed_data_folder': self.preprocessed_output_folder, 'num_classes': len(all_classes),
+ 'all_classes': all_classes, 'base_num_features': self.unet_base_num_features,
+ 'use_mask_for_norm': use_nonzero_mask_for_normalization,
+ 'keep_only_largest_region': only_keep_largest_connected_component,
+ 'min_region_size_per_class': min_region_size_per_class, 'min_size_per_class': min_size_per_class,
+ 'transpose_forward': self.transpose_forward, 'transpose_backward': self.transpose_backward,
+ 'data_identifier': self.data_identifier, 'plans_per_stage': self.plans_per_stage,
+ 'preprocessor_name': self.preprocessor_name,
+ 'conv_per_stage': self.conv_per_stage,
+ }
+
+ self.plans = plans
+ self.save_my_plans()
\ No newline at end of file
diff --git a/nnunet/experiment_planning/change_batch_size.py b/nnunet/experiment_planning/change_batch_size.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b8b735e8cf0ff8658659edd82af8cf7223e5e47
--- /dev/null
+++ b/nnunet/experiment_planning/change_batch_size.py
@@ -0,0 +1,9 @@
+from batchgenerators.utilities.file_and_folder_operations import *
+import numpy as np
+
+if __name__ == '__main__':
+ input_file = '/home/fabian/data/nnUNet_preprocessed/Task004_Hippocampus/nnUNetPlansv2.1_plans_3D.pkl'
+ output_file = '/home/fabian/data/nnUNet_preprocessed/Task004_Hippocampus/nnUNetPlansv2.1_LISA_plans_3D.pkl'
+ a = load_pickle(input_file)
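+    # shrink the stage-0 batch size to 6/9 (two thirds) of its planned value, presumably to fit a GPU with
+    # less memory; 'LISA' appears to be a site-specific name for the resulting plans file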
+ a['plans_per_stage'][0]['batch_size'] = int(np.floor(6 / 9 * a['plans_per_stage'][0]['batch_size']))
+ save_pickle(a, output_file)
\ No newline at end of file
diff --git a/nnunet/experiment_planning/common_utils.py b/nnunet/experiment_planning/common_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..78f40bc646f7d715013245cc3b07aa024f8c65e2
--- /dev/null
+++ b/nnunet/experiment_planning/common_utils.py
@@ -0,0 +1,267 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from copy import deepcopy
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+import SimpleITK as sitk
+import shutil
+from batchgenerators.utilities.file_and_folder_operations import join
+
+
+def split_4d_nifti(filename, output_folder, add_zeros=False):
+ img_itk = sitk.ReadImage(filename)
+ dim = img_itk.GetDimension()
+ file_base = filename.split("/")[-1]
+ if dim == 3:
+ shutil.copy(filename, join(output_folder, file_base[:-7] + "_0000.nii.gz"))
+ return
+ elif dim != 4:
+ raise RuntimeError("Unexpected dimensionality: %d of file %s, cannot split" % (dim, filename))
+ else:
+ img_npy = sitk.GetArrayFromImage(img_itk)
+ spacing = img_itk.GetSpacing()
+ origin = img_itk.GetOrigin()
+ direction = np.array(img_itk.GetDirection()).reshape(4,4)
+ # now modify these to remove the fourth dimension
+ spacing = tuple(list(spacing[:-1]))
+ origin = tuple(list(origin[:-1]))
+ direction = tuple(direction[:-1, :-1].reshape(-1))
+        for t in range(img_npy.shape[0]):
+ img = img_npy[t]
+ img_itk_new = sitk.GetImageFromArray(img)
+ img_itk_new.SetSpacing(spacing)
+ img_itk_new.SetOrigin(origin)
+ img_itk_new.SetDirection(direction)
+            sitk.WriteImage(img_itk_new, join(output_folder, file_base[:-7] + "_%04.0d.nii.gz" % t))
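+# usage sketch (hypothetical paths): splitting a 4D MSD image writes one 3D file per modality, e.g.
+# split_4d_nifti("/data/Task01_BrainTumour/imagesTr/BRATS_001.nii.gz", "/data/out") would produce
+# /data/out/BRATS_001_0000.nii.gz through BRATS_001_0003.nii.gz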
+
+
+def get_pool_and_conv_props_poolLateV2(patch_size, min_feature_map_size, max_numpool, spacing):
+ """
+
+ :param spacing:
+ :param patch_size:
+ :param min_feature_map_size: min edge length of feature maps in bottleneck
+    :param max_numpool: maximum number of pooling operations per axis
+    :return:
+ """
+ initial_spacing = deepcopy(spacing)
+ reach = max(initial_spacing)
+ dim = len(patch_size)
+
+ num_pool_per_axis = get_network_numpool(patch_size, max_numpool, min_feature_map_size)
+
+ net_num_pool_op_kernel_sizes = []
+ net_conv_kernel_sizes = []
+ net_numpool = max(num_pool_per_axis)
+
+ current_spacing = spacing
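+    # 'pool late': an axis with k pools only starts pooling in the last k of the net_numpool iterations, so
+    # all axes finish pooling together at the bottleneck. Conv kernels stay 1 along axes that are already
+    # coarse (current spacing above half the coarsest initial spacing) until the finer axes have caught up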
+ for p in range(net_numpool):
+ reached = [current_spacing[i] / reach > 0.5 for i in range(dim)]
+ pool = [2 if num_pool_per_axis[i] + p >= net_numpool else 1 for i in range(dim)]
+ if all(reached):
+ conv = [3] * dim
+ else:
+ conv = [3 if not reached[i] else 1 for i in range(dim)]
+ net_num_pool_op_kernel_sizes.append(pool)
+ net_conv_kernel_sizes.append(conv)
+ current_spacing = [i * j for i, j in zip(current_spacing, pool)]
+
+ net_conv_kernel_sizes.append([3] * dim)
+
+ must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
+ patch_size = pad_shape(patch_size, must_be_divisible_by)
+
+ # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here
+ return num_pool_per_axis, net_num_pool_op_kernel_sizes, net_conv_kernel_sizes, patch_size, must_be_divisible_by
+
+
+def get_pool_and_conv_props(spacing, patch_size, min_feature_map_size, max_numpool):
+ """
+
+ :param spacing:
+ :param patch_size:
+ :param min_feature_map_size: min edge length of feature maps in bottleneck
+    :param max_numpool: maximum number of pooling operations per axis
+    :return:
+ """
+ dim = len(spacing)
+
+ current_spacing = deepcopy(list(spacing))
+ current_size = deepcopy(list(patch_size))
+
+ pool_op_kernel_sizes = []
+ conv_kernel_sizes = []
+
+ num_pool_per_axis = [0] * dim
+
+ while True:
+ # This is a problem because sometimes we have spacing 20, 50, 50 and we want to still keep pooling.
+        # Here we would stop however. This is not what we want! Fixed in get_pool_and_conv_props_v2
+ min_spacing = min(current_spacing)
+ valid_axes_for_pool = [i for i in range(dim) if current_spacing[i] / min_spacing < 2]
+ axes = []
+ for a in range(dim):
+ my_spacing = current_spacing[a]
+ partners = [i for i in range(dim) if current_spacing[i] / my_spacing < 2 and my_spacing / current_spacing[i] < 2]
+ if len(partners) > len(axes):
+ axes = partners
+ conv_kernel_size = [3 if i in axes else 1 for i in range(dim)]
+
+ # exclude axes that we cannot pool further because of min_feature_map_size constraint
+ #before = len(valid_axes_for_pool)
+ valid_axes_for_pool = [i for i in valid_axes_for_pool if current_size[i] >= 2*min_feature_map_size]
+ #after = len(valid_axes_for_pool)
+ #if after == 1 and before > 1:
+ # break
+
+ valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool]
+
+ if len(valid_axes_for_pool) == 0:
+ break
+
+ #print(current_spacing, current_size)
+
+ other_axes = [i for i in range(dim) if i not in valid_axes_for_pool]
+
+ pool_kernel_sizes = [0] * dim
+ for v in valid_axes_for_pool:
+ pool_kernel_sizes[v] = 2
+ num_pool_per_axis[v] += 1
+ current_spacing[v] *= 2
+ current_size[v] = np.ceil(current_size[v] / 2)
+ for nv in other_axes:
+ pool_kernel_sizes[nv] = 1
+
+ pool_op_kernel_sizes.append(pool_kernel_sizes)
+ conv_kernel_sizes.append(conv_kernel_size)
+ #print(conv_kernel_sizes)
+
+ must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
+ patch_size = pad_shape(patch_size, must_be_divisible_by)
+
+ # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here
+ conv_kernel_sizes.append([3]*dim)
+ return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by
+
+
+def get_pool_and_conv_props_v2(spacing, patch_size, min_feature_map_size, max_numpool):
+ """
+
+ :param spacing:
+ :param patch_size:
+ :param min_feature_map_size: min edge length of feature maps in bottleneck
+    :param max_numpool: maximum number of pooling operations per axis
+    :return:
+ """
+ dim = len(spacing)
+
+ current_spacing = deepcopy(list(spacing))
+ current_size = deepcopy(list(patch_size))
+
+ pool_op_kernel_sizes = []
+ conv_kernel_sizes = []
+
+ num_pool_per_axis = [0] * dim
+ kernel_size = [1] * dim
+
+ while True:
+ # exclude axes that we cannot pool further because of min_feature_map_size constraint
+ valid_axes_for_pool = [i for i in range(dim) if current_size[i] >= 2*min_feature_map_size]
+ if len(valid_axes_for_pool) < 1:
+ break
+
+ spacings_of_axes = [current_spacing[i] for i in valid_axes_for_pool]
+
+        # find axes that are within a factor of 2 of the smallest spacing
+ min_spacing_of_valid = min(spacings_of_axes)
+ valid_axes_for_pool = [i for i in valid_axes_for_pool if current_spacing[i] / min_spacing_of_valid < 2]
+
+ # max_numpool constraint
+ valid_axes_for_pool = [i for i in valid_axes_for_pool if num_pool_per_axis[i] < max_numpool]
+
+ if len(valid_axes_for_pool) == 1:
+ if current_size[valid_axes_for_pool[0]] >= 3 * min_feature_map_size:
+ pass
+ else:
+ break
+ if len(valid_axes_for_pool) < 1:
+ break
+
+ # now we need to find kernel sizes
+ # kernel sizes are initialized to 1. They are successively set to 3 when their associated axis becomes within
+ # factor 2 of min_spacing. Once they are 3 they remain 3
+        for d in range(dim):
+            if kernel_size[d] == 3:
+                continue
+            else:
+                # index current_spacing directly: spacings_of_axes only covers the size-valid axes and can be
+                # shorter than dim, so spacings_of_axes[d] would misalign (or go out of bounds)
+                if current_spacing[d] / min(current_spacing) < 2:
+                    kernel_size[d] = 3
+
+ other_axes = [i for i in range(dim) if i not in valid_axes_for_pool]
+
+ pool_kernel_sizes = [0] * dim
+ for v in valid_axes_for_pool:
+ pool_kernel_sizes[v] = 2
+ num_pool_per_axis[v] += 1
+ current_spacing[v] *= 2
+ current_size[v] = np.ceil(current_size[v] / 2)
+ for nv in other_axes:
+ pool_kernel_sizes[nv] = 1
+
+ pool_op_kernel_sizes.append(pool_kernel_sizes)
+ conv_kernel_sizes.append(deepcopy(kernel_size))
+ #print(conv_kernel_sizes)
+
+ must_be_divisible_by = get_shape_must_be_divisible_by(num_pool_per_axis)
+ patch_size = pad_shape(patch_size, must_be_divisible_by)
+
+ # we need to add one more conv_kernel_size for the bottleneck. We always use 3x3(x3) conv here
+ conv_kernel_sizes.append([3]*dim)
+ return num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, patch_size, must_be_divisible_by
+
+
+def get_shape_must_be_divisible_by(net_numpool_per_axis):
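+    # each pooling halves an axis, so after k pools the patch edge must be divisible by 2**k,
+    # e.g. num_pool_per_axis (3, 5, 5) -> (8, 32, 32)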
+ return 2 ** np.array(net_numpool_per_axis)
+
+
+def pad_shape(shape, must_be_divisible_by):
+ """
+    pads shape so that it is divisible by must_be_divisible_by
+ :param shape:
+ :param must_be_divisible_by:
+ :return:
+ """
+ if not isinstance(must_be_divisible_by, (tuple, list, np.ndarray)):
+ must_be_divisible_by = [must_be_divisible_by] * len(shape)
+ else:
+ assert len(must_be_divisible_by) == len(shape)
+
+ new_shp = [shape[i] + must_be_divisible_by[i] - shape[i] % must_be_divisible_by[i] for i in range(len(shape))]
+
+ for i in range(len(shape)):
+ if shape[i] % must_be_divisible_by[i] == 0:
+ new_shp[i] -= must_be_divisible_by[i]
+ new_shp = np.array(new_shp).astype(int)
+ return new_shp
+
+
+def get_network_numpool(patch_size, maxpool_cap=999, min_feature_map_size=4):
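+    # pool each axis only as long as the bottleneck edge stays >= min_feature_map_size,
+    # e.g. a patch edge of 128 with min_feature_map_size=4 yields floor(log2(128/4)) = 5 pools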
+ network_numpool_per_axis = np.floor([np.log(i / min_feature_map_size) / np.log(2) for i in patch_size]).astype(int)
+ network_numpool_per_axis = [min(i, maxpool_cap) for i in network_numpool_per_axis]
+ return network_numpool_per_axis
+
+
+if __name__ == '__main__':
+ # trying to fix https://github.com/MIC-DKFZ/nnUNet/issues/261
+ median_shape = [24, 504, 512]
+ spacing = [5.9999094, 0.50781202, 0.50781202]
+ num_pool_per_axis, net_num_pool_op_kernel_sizes, net_conv_kernel_sizes, patch_size, must_be_divisible_by = get_pool_and_conv_props_poolLateV2(median_shape, min_feature_map_size=4, max_numpool=999, spacing=spacing)
diff --git a/nnunet/experiment_planning/experiment_planner_baseline_2DUNet.py b/nnunet/experiment_planning/experiment_planner_baseline_2DUNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3eaa0520848d12ad18e885ba437ffcbae3c685ec
--- /dev/null
+++ b/nnunet/experiment_planning/experiment_planner_baseline_2DUNet.py
@@ -0,0 +1,158 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import shutil
+
+import nnunet
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import load_pickle, subfiles
+from multiprocessing.pool import Pool
+from nnunet.configuration import default_num_threads
+from nnunet.experiment_planning.common_utils import get_pool_and_conv_props
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
+from nnunet.experiment_planning.utils import add_classes_in_slice_info
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.paths import *
+from nnunet.preprocessing.preprocessing import PreprocessorFor2D
+from nnunet.training.model_restore import recursive_find_python_class
+
+
+class ExperimentPlanner2D(ExperimentPlanner):
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super(ExperimentPlanner2D, self).__init__(folder_with_cropped_data,
+ preprocessed_output_folder)
+ self.data_identifier = default_data_identifier + "_2D"
+ self.plans_fname = join(self.preprocessed_output_folder, "nnUNetPlans" + "_plans_2D.pkl")
+
+ self.unet_base_num_features = 30
+ self.unet_max_num_filters = 512
+ self.unet_max_numpool = 999
+
+ self.preprocessor_name = "PreprocessorFor2D"
+
+ def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
+ num_modalities, num_classes):
+
+ new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
+
+ dataset_num_voxels = np.prod(new_median_shape, dtype=np.int64) * num_cases
+ input_patch_size = new_median_shape[1:]
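+        # 2D U-Net: drop the (transposed) out-of-plane axis; the patch starts out as the full median slice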
+
+ network_numpool, net_pool_kernel_sizes, net_conv_kernel_sizes, input_patch_size, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing[1:], input_patch_size,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool)
+
+ estimated_gpu_ram_consumption = Generic_UNet.compute_approx_vram_consumption(input_patch_size,
+ network_numpool,
+ self.unet_base_num_features,
+ self.unet_max_num_filters,
+ num_modalities, num_classes,
+ net_pool_kernel_sizes,
+ conv_per_stage=self.conv_per_stage)
+
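+        # scale the reference batch size by how much cheaper this configuration is than the 2D reference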
+ batch_size = int(np.floor(Generic_UNet.use_this_for_batch_size_computation_2D /
+ estimated_gpu_ram_consumption * Generic_UNet.DEFAULT_BATCH_SIZE_2D))
+ if batch_size < self.unet_min_batch_size:
+ raise RuntimeError("This framework is not made to process patches this large. We will add patch-based "
+ "2D networks later. Sorry for the inconvenience")
+
+ # check if batch size is too large (more than 5 % of dataset)
+ max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
+ np.prod(input_patch_size, dtype=np.int64)).astype(int)
+ batch_size = max(1, min(batch_size, max_batch_size))
+
+ plan = {
+ 'batch_size': batch_size,
+ 'num_pool_per_axis': network_numpool,
+ 'patch_size': input_patch_size,
+ 'median_patient_size_in_voxels': new_median_shape,
+ 'current_spacing': current_spacing,
+ 'original_spacing': original_spacing,
+ 'pool_op_kernel_sizes': net_pool_kernel_sizes,
+ 'conv_kernel_sizes': net_conv_kernel_sizes,
+ 'do_dummy_2D_data_aug': False
+ }
+ return plan
+
+ def plan_experiment(self):
+ use_nonzero_mask_for_normalization = self.determine_whether_to_use_mask_for_norm()
+ print("Are we using the nonzero mask for normalization?", use_nonzero_mask_for_normalization)
+
+ spacings = self.dataset_properties['all_spacings']
+ sizes = self.dataset_properties['all_sizes']
+ all_classes = self.dataset_properties['all_classes']
+ modalities = self.dataset_properties['modalities']
+ num_modalities = len(list(modalities.keys()))
+
+ target_spacing = self.get_target_spacing()
+ new_shapes = np.array([np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)])
+
+ max_spacing_axis = np.argmax(target_spacing)
+ remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis]
+ self.transpose_forward = [max_spacing_axis] + remaining_axes
+ self.transpose_backward = [np.argwhere(np.array(self.transpose_forward) == i)[0][0] for i in range(3)]
+
+ # we base our calculations on the median shape of the datasets
+ median_shape = np.median(np.vstack(new_shapes), 0)
+ print("the median shape of the dataset is ", median_shape)
+
+ max_shape = np.max(np.vstack(new_shapes), 0)
+ print("the max shape in the dataset is ", max_shape)
+ min_shape = np.min(np.vstack(new_shapes), 0)
+ print("the min shape in the dataset is ", min_shape)
+
+ print("we don't want feature maps smaller than ", self.unet_featuremap_min_edge_length, " in the bottleneck")
+
+ # how many stages will the image pyramid have?
+ self.plans_per_stage = []
+
+ target_spacing_transposed = np.array(target_spacing)[self.transpose_forward]
+ median_shape_transposed = np.array(median_shape)[self.transpose_forward]
+ print("the transposed median shape of the dataset is ", median_shape_transposed)
+
+ self.plans_per_stage.append(
+ self.get_properties_for_stage(target_spacing_transposed, target_spacing_transposed, median_shape_transposed,
+ num_cases=len(self.list_of_cropped_npz_files),
+ num_modalities=num_modalities,
+ num_classes=len(all_classes) + 1),
+ )
+
+ print(self.plans_per_stage)
+
+ self.plans_per_stage = self.plans_per_stage[::-1]
+ self.plans_per_stage = {i: self.plans_per_stage[i] for i in range(len(self.plans_per_stage))} # convert to dict
+
+ normalization_schemes = self.determine_normalization_scheme()
+ # deprecated
+ only_keep_largest_connected_component, min_size_per_class, min_region_size_per_class = None, None, None
+
+ # these are independent of the stage
+ plans = {'num_stages': len(list(self.plans_per_stage.keys())), 'num_modalities': num_modalities,
+ 'modalities': modalities, 'normalization_schemes': normalization_schemes,
+ 'dataset_properties': self.dataset_properties, 'list_of_npz_files': self.list_of_cropped_npz_files,
+ 'original_spacings': spacings, 'original_sizes': sizes,
+ 'preprocessed_data_folder': self.preprocessed_output_folder, 'num_classes': len(all_classes),
+ 'all_classes': all_classes, 'base_num_features': self.unet_base_num_features,
+ 'use_mask_for_norm': use_nonzero_mask_for_normalization,
+ 'keep_only_largest_region': only_keep_largest_connected_component,
+ 'min_region_size_per_class': min_region_size_per_class, 'min_size_per_class': min_size_per_class,
+ 'transpose_forward': self.transpose_forward, 'transpose_backward': self.transpose_backward,
+ 'data_identifier': self.data_identifier, 'plans_per_stage': self.plans_per_stage,
+ 'preprocessor_name': self.preprocessor_name,
+ }
+
+ self.plans = plans
+ self.save_my_plans()
diff --git a/nnunet/experiment_planning/experiment_planner_baseline_2DUNet_v21.py b/nnunet/experiment_planning/experiment_planner_baseline_2DUNet_v21.py
new file mode 100644
index 0000000000000000000000000000000000000000..146d436b1920244fdf02c8af80405dfb38781a32
--- /dev/null
+++ b/nnunet/experiment_planning/experiment_planner_baseline_2DUNet_v21.py
@@ -0,0 +1,100 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from copy import deepcopy
+
+from nnunet.experiment_planning.common_utils import get_pool_and_conv_props
+from nnunet.experiment_planning.experiment_planner_baseline_2DUNet import ExperimentPlanner2D
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.paths import *
+import numpy as np
+
+
+class ExperimentPlanner2D_v21(ExperimentPlanner2D):
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super(ExperimentPlanner2D_v21, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.data_identifier = "nnUNetData_plans_v2.1_2D"
+ self.plans_fname = join(self.preprocessed_output_folder,
+ "nnUNetPlansv2.1_plans_2D.pkl")
+ self.unet_base_num_features = 32
+
+ def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
+ num_modalities, num_classes):
+
+ new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
+
+ dataset_num_voxels = np.prod(new_median_shape, dtype=np.int64) * num_cases
+ input_patch_size = new_median_shape[1:]
+
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing[1:], input_patch_size,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool)
+
+ # we pretend to use 30 feature maps. This will yield the same configuration as in V1. The larger memory
+        # footprint of 32 vs 30 is more than offset by the fp16 training. We make fp16 training the default.
+        # The reason for 32 vs 30 feature maps is that 32 is faster in fp16 training (because it is a multiple of 8)
+ ref = Generic_UNet.use_this_for_batch_size_computation_2D * Generic_UNet.DEFAULT_BATCH_SIZE_2D / 2 # for batch size 2
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp,
+ network_num_pool_per_axis,
+ 30,
+ self.unet_max_num_filters,
+ num_modalities, num_classes,
+ pool_op_kernel_sizes,
+ conv_per_stage=self.conv_per_stage)
+ while here > ref:
+ axis_to_be_reduced = np.argsort(new_shp / new_median_shape[1:])[-1]
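+            # shrink the axis whose patch edge overshoots its share of the median shape the most, one
+            # divisibility step at a time, until the estimated VRAM consumption fits the budget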
+
+ tmp = deepcopy(new_shp)
+ tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
+ _, _, _, _, shape_must_be_divisible_by_new = \
+ get_pool_and_conv_props(current_spacing[1:], tmp, self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool)
+ new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]
+
+ # we have to recompute numpool now:
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing[1:], new_shp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool)
+
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes, pool_op_kernel_sizes,
+ conv_per_stage=self.conv_per_stage)
+ # print(new_shp)
+
+ batch_size = int(np.floor(ref / here) * 2)
+ input_patch_size = new_shp
+
+ if batch_size < self.unet_min_batch_size:
+ raise RuntimeError("This should not happen")
+
+ # check if batch size is too large (more than 5 % of dataset)
+ max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
+ np.prod(input_patch_size, dtype=np.int64)).astype(int)
+ batch_size = max(1, min(batch_size, max_batch_size))
+
+ plan = {
+ 'batch_size': batch_size,
+ 'num_pool_per_axis': network_num_pool_per_axis,
+ 'patch_size': input_patch_size,
+ 'median_patient_size_in_voxels': new_median_shape,
+ 'current_spacing': current_spacing,
+ 'original_spacing': original_spacing,
+ 'pool_op_kernel_sizes': pool_op_kernel_sizes,
+ 'conv_kernel_sizes': conv_kernel_sizes,
+ 'do_dummy_2D_data_aug': False
+ }
+ return plan
diff --git a/nnunet/experiment_planning/experiment_planner_baseline_3DUNet.py b/nnunet/experiment_planning/experiment_planner_baseline_3DUNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a8b65dc30e91df30dfcd055ecbac898257f891b
--- /dev/null
+++ b/nnunet/experiment_planning/experiment_planner_baseline_3DUNet.py
@@ -0,0 +1,494 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+from collections import OrderedDict
+from copy import deepcopy
+
+import nnunet
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.configuration import default_num_threads
+from nnunet.experiment_planning.DatasetAnalyzer import DatasetAnalyzer
+from nnunet.experiment_planning.common_utils import get_pool_and_conv_props_poolLateV2
+from nnunet.experiment_planning.utils import create_lists_from_splitted_dataset
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.paths import *
+from nnunet.preprocessing.cropping import get_case_identifier_from_npz
+from nnunet.training.model_restore import recursive_find_python_class
+
+
+class ExperimentPlanner(object):
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ self.folder_with_cropped_data = folder_with_cropped_data
+ self.preprocessed_output_folder = preprocessed_output_folder
+ self.list_of_cropped_npz_files = subfiles(self.folder_with_cropped_data, True, None, ".npz", True)
+
+ self.preprocessor_name = "GenericPreprocessor"
+
+ assert isfile(join(self.folder_with_cropped_data, "dataset_properties.pkl")), \
+ "folder_with_cropped_data must contain dataset_properties.pkl"
+ self.dataset_properties = load_pickle(join(self.folder_with_cropped_data, "dataset_properties.pkl"))
+
+ self.plans_per_stage = OrderedDict()
+ self.plans = OrderedDict()
+ self.plans_fname = join(self.preprocessed_output_folder, "nnUNetPlans" + "fixed_plans_3D.pkl")
+ self.data_identifier = default_data_identifier
+
+ self.transpose_forward = [0, 1, 2]
+ self.transpose_backward = [0, 1, 2]
+
+ self.unet_base_num_features = Generic_UNet.BASE_NUM_FEATURES_3D
+ self.unet_max_num_filters = 320
+ self.unet_max_numpool = 999
+ self.unet_min_batch_size = 2
+ self.unet_featuremap_min_edge_length = 4
+
+ self.target_spacing_percentile = 50
+ self.anisotropy_threshold = 3
+ self.how_much_of_a_patient_must_the_network_see_at_stage0 = 4 # 1/4 of a patient
+ self.batch_size_covers_max_percent_of_dataset = 0.05 # all samples in the batch together cannot cover more
+ # than 5% of the entire dataset
+
+ self.conv_per_stage = 2
+
+ def get_target_spacing(self):
+ spacings = self.dataset_properties['all_spacings']
+
+ # target = np.median(np.vstack(spacings), 0)
+ # if target spacing is very anisotropic we may want to not downsample the axis with the worst spacing
+ # uncomment after mystery task submission
+ """worst_spacing_axis = np.argmax(target)
+ if max(target) > (2.5 * min(target)):
+ spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis]
+ target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 5)
+ target[worst_spacing_axis] = target_spacing_of_that_axis"""
+
+ target = np.percentile(np.vstack(spacings), self.target_spacing_percentile, 0)
+ return target
+
+ def save_my_plans(self):
+ with open(self.plans_fname, 'wb') as f:
+ pickle.dump(self.plans, f)
+
+ def load_my_plans(self):
+ self.plans = load_pickle(self.plans_fname)
+
+ self.plans_per_stage = self.plans['plans_per_stage']
+ self.dataset_properties = self.plans['dataset_properties']
+
+ self.transpose_forward = self.plans['transpose_forward']
+ self.transpose_backward = self.plans['transpose_backward']
+
+ def determine_postprocessing(self):
+ pass
+ """
+ Spoiler: This is unused, postprocessing was removed. Ignore it.
+ :return:
+ print("determining postprocessing...")
+
+ props_per_patient = self.dataset_properties['segmentation_props_per_patient']
+
+ all_region_keys = [i for k in props_per_patient.keys() for i in props_per_patient[k]['only_one_region'].keys()]
+ all_region_keys = list(set(all_region_keys))
+
+ only_keep_largest_connected_component = OrderedDict()
+
+ for r in all_region_keys:
+ all_results = [props_per_patient[k]['only_one_region'][r] for k in props_per_patient.keys()]
+ only_keep_largest_connected_component[tuple(r)] = all(all_results)
+
+ print("Postprocessing: only_keep_largest_connected_component", only_keep_largest_connected_component)
+
+ all_classes = self.dataset_properties['all_classes']
+ classes = [i for i in all_classes if i > 0]
+
+ props_per_patient = self.dataset_properties['segmentation_props_per_patient']
+
+ min_size_per_class = OrderedDict()
+ for c in classes:
+ all_num_voxels = []
+ for k in props_per_patient.keys():
+ all_num_voxels.append(props_per_patient[k]['volume_per_class'][c])
+ if len(all_num_voxels) > 0:
+ min_size_per_class[c] = np.percentile(all_num_voxels, 1) * MIN_SIZE_PER_CLASS_FACTOR
+ else:
+ min_size_per_class[c] = np.inf
+
+ min_region_size_per_class = OrderedDict()
+ for c in classes:
+ region_sizes = [l for k in props_per_patient for l in props_per_patient[k]['region_volume_per_class'][c]]
+ if len(region_sizes) > 0:
+ min_region_size_per_class[c] = min(region_sizes)
+ # we don't need that line but better safe than sorry, right?
+ min_region_size_per_class[c] = min(min_region_size_per_class[c], min_size_per_class[c])
+ else:
+ min_region_size_per_class[c] = 0
+
+ print("Postprocessing: min_size_per_class", min_size_per_class)
+ print("Postprocessing: min_region_size_per_class", min_region_size_per_class)
+ return only_keep_largest_connected_component, min_size_per_class, min_region_size_per_class
+ """
+
+ def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
+ num_modalities, num_classes):
+ """
+ Computation of input patch size starts out with the new median shape (in voxels) of a dataset. This is
+ opposed to prior experiments where I based it on the median size in mm. The rationale behind this is that
+ for some organ of interest the acquisition method will most likely be chosen such that the field of view and
+ voxel resolution go hand in hand to show the doctor what they need to see. This assumption may be violated
+        for some modalities with anisotropy (cine MRI) but we will have to live with that. In future experiments I
+        will try to 1) make the input patch size match the aspect ratio of the input size in mm (instead of voxels)
+        and 2) enforce that we see the same 'distance' in all directions (try to maintain an equal patch size in mm)
+
+        The patches created here attempt to keep the aspect ratio of the new_median_shape
+
+ :param current_spacing:
+ :param original_spacing:
+ :param original_shape:
+ :param num_cases:
+ :return:
+ """
+ new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
+ dataset_num_voxels = np.prod(new_median_shape) * num_cases
+
+        # the next line is what we had before as a default. The patch size had the same aspect ratio as the
+        # median shape of a patient. We swapped this for the spacing-based computation below.
+ # input_patch_size = new_median_shape
+
+ # compute how many voxels are one mm
+ input_patch_size = 1 / np.array(current_spacing)
+
+ # normalize voxels per mm
+ input_patch_size /= input_patch_size.mean()
+
+ # create an isotropic patch of size 512x512x512mm
+ input_patch_size *= 1 / min(input_patch_size) * 512 # to get a starting value
+ input_patch_size = np.round(input_patch_size).astype(int)
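+        # e.g. for a spacing of (3.0, 0.7, 0.7) mm this starting value is roughly (512, 2194, 2194) voxels,
+        # which is then clipped to the median shape below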
+
+        # clip it to the median shape of the dataset because patches larger than that make little sense
+ input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]
+
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(input_patch_size,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ current_spacing)
+
+ ref = Generic_UNet.use_this_for_batch_size_computation_3D
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes,
+ pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
+ while here > ref:
+ axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]
+
+ tmp = deepcopy(new_shp)
+ tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
+ _, _, _, _, shape_must_be_divisible_by_new = \
+ get_pool_and_conv_props_poolLateV2(tmp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ current_spacing)
+ new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]
+
+ # we have to recompute numpool now:
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props_poolLateV2(new_shp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ current_spacing)
+
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes, pool_op_kernel_sizes,
+ conv_per_stage=self.conv_per_stage)
+ # print(new_shp)
+
+ input_patch_size = new_shp
+
+ batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D # This is what works with 128**3
+ batch_size = int(np.floor(max(ref / here, 1) * batch_size))
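+        # if this configuration needs less VRAM than the reference, the surplus is spent on a larger batch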
+
+ # check if batch size is too large
+ max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
+ np.prod(input_patch_size, dtype=np.int64)).astype(int)
+ max_batch_size = max(max_batch_size, self.unet_min_batch_size)
+ batch_size = max(1, min(batch_size, max_batch_size))
+
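+        # use 'dummy 2D' augmentation when the patch is very anisotropic, i.e. its largest edge is more than
+        # anisotropy_threshold times the (transposed) out-of-plane edge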
+        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[0]) > self.anisotropy_threshold
+
+ plan = {
+ 'batch_size': batch_size,
+ 'num_pool_per_axis': network_num_pool_per_axis,
+ 'patch_size': input_patch_size,
+ 'median_patient_size_in_voxels': new_median_shape,
+ 'current_spacing': current_spacing,
+ 'original_spacing': original_spacing,
+ 'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
+ 'pool_op_kernel_sizes': pool_op_kernel_sizes,
+ 'conv_kernel_sizes': conv_kernel_sizes,
+ }
+ return plan
+
+ def plan_experiment(self):
+ use_nonzero_mask_for_normalization = self.determine_whether_to_use_mask_for_norm()
+ print("Are we using the nonzero mask for normalization?", use_nonzero_mask_for_normalization)
+ spacings = self.dataset_properties['all_spacings']
+ sizes = self.dataset_properties['all_sizes']
+
+ all_classes = self.dataset_properties['all_classes']
+ modalities = self.dataset_properties['modalities']
+ num_modalities = len(list(modalities.keys()))
+
+ target_spacing = self.get_target_spacing()
+ new_shapes = [np.array(i) / target_spacing * np.array(j) for i, j in zip(spacings, sizes)]
+
+ max_spacing_axis = np.argmax(target_spacing)
+ remaining_axes = [i for i in list(range(3)) if i != max_spacing_axis]
+ self.transpose_forward = [max_spacing_axis] + remaining_axes
+ self.transpose_backward = [np.argwhere(np.array(self.transpose_forward) == i)[0][0] for i in range(3)]
+
+ # we base our calculations on the median shape of the datasets
+ median_shape = np.median(np.vstack(new_shapes), 0)
+ print("the median shape of the dataset is ", median_shape)
+
+ max_shape = np.max(np.vstack(new_shapes), 0)
+ print("the max shape in the dataset is ", max_shape)
+ min_shape = np.min(np.vstack(new_shapes), 0)
+ print("the min shape in the dataset is ", min_shape)
+
+ print("we don't want feature maps smaller than ", self.unet_featuremap_min_edge_length, " in the bottleneck")
+
+ # how many stages will the image pyramid have?
+ self.plans_per_stage = list()
+
+ target_spacing_transposed = np.array(target_spacing)[self.transpose_forward]
+ median_shape_transposed = np.array(median_shape)[self.transpose_forward]
+ print("the transposed median shape of the dataset is ", median_shape_transposed)
+
+ print("generating configuration for 3d_fullres")
+ self.plans_per_stage.append(self.get_properties_for_stage(target_spacing_transposed, target_spacing_transposed,
+ median_shape_transposed,
+ len(self.list_of_cropped_npz_files),
+ num_modalities, len(all_classes) + 1))
+
+ # thanks Zakiyi (https://github.com/MIC-DKFZ/nnUNet/issues/61) for spotting this bug :-)
+ # if np.prod(self.plans_per_stage[-1]['median_patient_size_in_voxels'], dtype=np.int64) / \
+ # architecture_input_voxels < HOW_MUCH_OF_A_PATIENT_MUST_THE_NETWORK_SEE_AT_STAGE0:
+ architecture_input_voxels_here = np.prod(self.plans_per_stage[-1]['patch_size'], dtype=np.int64)
+ if np.prod(median_shape) / architecture_input_voxels_here < \
+ self.how_much_of_a_patient_must_the_network_see_at_stage0:
+ more = False
+ else:
+ more = True
+
+ if more:
+ print("generating configuration for 3d_lowres")
+ # if we are doing more than one stage then we want the lowest stage to have exactly
+ # HOW_MUCH_OF_A_PATIENT_MUST_THE_NETWORK_SEE_AT_STAGE0 (this is 4 by default so the number of voxels in the
+ # median shape of the lowest stage must be 4 times as much as the network can process at once (128x128x128 by
+ # default). Problem is that we are downsampling higher resolution axes before we start downsampling the
+ # out-of-plane axis. We could probably/maybe do this analytically but I am lazy, so here
+ # we do it the dumb way
+
+ lowres_stage_spacing = deepcopy(target_spacing)
+ num_voxels = np.prod(median_shape, dtype=np.float64)
+ while num_voxels > self.how_much_of_a_patient_must_the_network_see_at_stage0 * architecture_input_voxels_here:
+ max_spacing = max(lowres_stage_spacing)
+ if np.any((max_spacing / lowres_stage_spacing) > 2):
+ lowres_stage_spacing[(max_spacing / lowres_stage_spacing) > 2] \
+ *= 1.01
+ else:
+ lowres_stage_spacing *= 1.01
+ num_voxels = np.prod(target_spacing / lowres_stage_spacing * median_shape, dtype=np.float64)
+
+ lowres_stage_spacing_transposed = np.array(lowres_stage_spacing)[self.transpose_forward]
+ new = self.get_properties_for_stage(lowres_stage_spacing_transposed, target_spacing_transposed,
+ median_shape_transposed,
+ len(self.list_of_cropped_npz_files),
+ num_modalities, len(all_classes) + 1)
+ architecture_input_voxels_here = np.prod(new['patch_size'], dtype=np.int64)
+ if 2 * np.prod(new['median_patient_size_in_voxels'], dtype=np.int64) < np.prod(
+ self.plans_per_stage[0]['median_patient_size_in_voxels'], dtype=np.int64):
+ self.plans_per_stage.append(new)
+
+ self.plans_per_stage = self.plans_per_stage[::-1]
+ self.plans_per_stage = {i: self.plans_per_stage[i] for i in range(len(self.plans_per_stage))} # convert to dict
+
+ print(self.plans_per_stage)
+ print("transpose forward", self.transpose_forward)
+ print("transpose backward", self.transpose_backward)
+
+ normalization_schemes = self.determine_normalization_scheme()
+ only_keep_largest_connected_component, min_size_per_class, min_region_size_per_class = None, None, None
+ # removed training data based postprocessing. This is deprecated
+
+ # these are independent of the stage
+ plans = {'num_stages': len(list(self.plans_per_stage.keys())), 'num_modalities': num_modalities,
+ 'modalities': modalities, 'normalization_schemes': normalization_schemes,
+ 'dataset_properties': self.dataset_properties, 'list_of_npz_files': self.list_of_cropped_npz_files,
+ 'original_spacings': spacings, 'original_sizes': sizes,
+ 'preprocessed_data_folder': self.preprocessed_output_folder, 'num_classes': len(all_classes),
+ 'all_classes': all_classes, 'base_num_features': self.unet_base_num_features,
+ 'use_mask_for_norm': use_nonzero_mask_for_normalization,
+ 'keep_only_largest_region': only_keep_largest_connected_component,
+ 'min_region_size_per_class': min_region_size_per_class, 'min_size_per_class': min_size_per_class,
+ 'transpose_forward': self.transpose_forward, 'transpose_backward': self.transpose_backward,
+ 'data_identifier': self.data_identifier, 'plans_per_stage': self.plans_per_stage,
+ 'preprocessor_name': self.preprocessor_name,
+ 'conv_per_stage': self.conv_per_stage,
+ }
+
+ self.plans = plans
+ self.save_my_plans()
+
+ def determine_normalization_scheme(self):
+ schemes = OrderedDict()
+ modalities = self.dataset_properties['modalities']
+ num_modalities = len(list(modalities.keys()))
+
+ for i in range(num_modalities):
+ if modalities[i] == "CT" or modalities[i] == 'ct':
+ schemes[i] = "CT"
+ elif modalities[i] == 'noNorm':
+ schemes[i] = "noNorm"
+ else:
+ schemes[i] = "nonCT"
+ return schemes
+
+ def save_properties_of_cropped(self, case_identifier, properties):
+ with open(join(self.folder_with_cropped_data, "%s.pkl" % case_identifier), 'wb') as f:
+ pickle.dump(properties, f)
+
+ def load_properties_of_cropped(self, case_identifier):
+ with open(join(self.folder_with_cropped_data, "%s.pkl" % case_identifier), 'rb') as f:
+ properties = pickle.load(f)
+ return properties
+
+ def determine_whether_to_use_mask_for_norm(self):
+        # only use the nonzero mask for normalization if the cropping based on it resulted in a decrease in
+        # image size (this is an indication that the data is something like brats/isles and we then want to
+        # normalize in the brain region only)
+ modalities = self.dataset_properties['modalities']
+ num_modalities = len(list(modalities.keys()))
+ use_nonzero_mask_for_norm = OrderedDict()
+
+ for i in range(num_modalities):
+ if "CT" in modalities[i]:
+ use_nonzero_mask_for_norm[i] = False
+ else:
+ all_size_reductions = []
+ for k in self.dataset_properties['size_reductions'].keys():
+ all_size_reductions.append(self.dataset_properties['size_reductions'][k])
+
+ if np.median(all_size_reductions) < 3 / 4.:
+ print("using nonzero mask for normalization")
+ use_nonzero_mask_for_norm[i] = True
+ else:
+ print("not using nonzero mask for normalization")
+ use_nonzero_mask_for_norm[i] = False
+
+ for c in self.list_of_cropped_npz_files:
+ case_identifier = get_case_identifier_from_npz(c)
+ properties = self.load_properties_of_cropped(case_identifier)
+ properties['use_nonzero_mask_for_norm'] = use_nonzero_mask_for_norm
+ self.save_properties_of_cropped(case_identifier, properties)
+ use_nonzero_mask_for_normalization = use_nonzero_mask_for_norm
+ return use_nonzero_mask_for_normalization
+
+ def write_normalization_scheme_to_patients(self):
+ """
+ This is used for test set preprocessing
+ :return:
+ """
+ for c in self.list_of_cropped_npz_files:
+ case_identifier = get_case_identifier_from_npz(c)
+ properties = self.load_properties_of_cropped(case_identifier)
+ properties['use_nonzero_mask_for_norm'] = self.plans['use_mask_for_norm']
+ self.save_properties_of_cropped(case_identifier, properties)
+
+ def run_preprocessing(self, num_threads):
+ if os.path.isdir(join(self.preprocessed_output_folder, "gt_segmentations")):
+ shutil.rmtree(join(self.preprocessed_output_folder, "gt_segmentations"))
+ shutil.copytree(join(self.folder_with_cropped_data, "gt_segmentations"),
+ join(self.preprocessed_output_folder, "gt_segmentations"))
+ normalization_schemes = self.plans['normalization_schemes']
+ use_nonzero_mask_for_normalization = self.plans['use_mask_for_norm']
+ intensityproperties = self.plans['dataset_properties']['intensityproperties']
+ preprocessor_class = recursive_find_python_class([join(nnunet.__path__[0], "preprocessing")],
+ self.preprocessor_name, current_module="nnunet.preprocessing")
+ assert preprocessor_class is not None
+ preprocessor = preprocessor_class(normalization_schemes, use_nonzero_mask_for_normalization,
+ self.transpose_forward,
+ intensityproperties)
+ target_spacings = [i["current_spacing"] for i in self.plans_per_stage.values()]
+ if self.plans['num_stages'] > 1 and not isinstance(num_threads, (list, tuple)):
+ num_threads = (default_num_threads, num_threads)
+ elif self.plans['num_stages'] == 1 and isinstance(num_threads, (list, tuple)):
+ num_threads = num_threads[-1]
+ preprocessor.run(target_spacings, self.folder_with_cropped_data, self.preprocessed_output_folder,
+ self.plans['data_identifier'], num_threads)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-t", "--task_ids", nargs="+", help="list of int")
+ parser.add_argument("-p", action="store_true", help="set this if you actually want to run the preprocessing. If "
+ "this is not set then this script will only create the plans file")
+ parser.add_argument("-tl", type=int, required=False, default=8, help="num_threads_lowres")
+ parser.add_argument("-tf", type=int, required=False, default=8, help="num_threads_fullres")
+
+ args = parser.parse_args()
+ task_ids = args.task_ids
+ run_preprocessing = args.p
+ tl = args.tl
+ tf = args.tf
+
+ tasks = []
+ for i in task_ids:
+ i = int(i)
+ candidates = subdirs(nnUNet_cropped_data, prefix="Task%03.0d" % i, join=False)
+ assert len(candidates) == 1
+ tasks.append(candidates[0])
+
+ for t in tasks:
+ try:
+ print("\n\n\n", t)
+ cropped_out_dir = os.path.join(nnUNet_cropped_data, t)
+ preprocessing_output_dir_this_task = os.path.join(preprocessing_output_dir, t)
+ splitted_4d_output_dir_task = os.path.join(nnUNet_raw_data, t)
+ lists, modalities = create_lists_from_splitted_dataset(splitted_4d_output_dir_task)
+
+ dataset_analyzer = DatasetAnalyzer(cropped_out_dir, overwrite=False)
+ _ = dataset_analyzer.analyze_dataset() # this will write output files that will be used by the ExperimentPlanner
+
+ maybe_mkdir_p(preprocessing_output_dir_this_task)
+ shutil.copy(join(cropped_out_dir, "dataset_properties.pkl"), preprocessing_output_dir_this_task)
+ shutil.copy(join(nnUNet_raw_data, t, "dataset.json"), preprocessing_output_dir_this_task)
+
+ threads = (tl, tf)
+
+ print("number of threads: ", threads, "\n")
+
+ exp_planner = ExperimentPlanner(cropped_out_dir, preprocessing_output_dir_this_task)
+ exp_planner.plan_experiment()
+ if run_preprocessing:
+ exp_planner.run_preprocessing(threads)
+ except Exception as e:
+ print(e)
diff --git a/nnunet/experiment_planning/experiment_planner_baseline_3DUNet_v21.py b/nnunet/experiment_planning/experiment_planner_baseline_3DUNet_v21.py
new file mode 100644
index 0000000000000000000000000000000000000000..24faa2446e7a744b4dbe604fba0fb0a841244a27
--- /dev/null
+++ b/nnunet/experiment_planning/experiment_planner_baseline_3DUNet_v21.py
@@ -0,0 +1,179 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+
+import numpy as np
+from nnunet.experiment_planning.common_utils import get_pool_and_conv_props
+from nnunet.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.paths import *
+
+
+class ExperimentPlanner3D_v21(ExperimentPlanner):
+ """
+ Combines ExperimentPlannerPoolBasedOnSpacing and ExperimentPlannerTargetSpacingForAnisoAxis
+
+ We also increase the base_num_features to 32. This is solely because mixed precision training with 3D convs and
+ amp is A LOT faster if the number of filters is divisible by 8
+ """
+ def __init__(self, folder_with_cropped_data, preprocessed_output_folder):
+ super(ExperimentPlanner3D_v21, self).__init__(folder_with_cropped_data, preprocessed_output_folder)
+ self.data_identifier = "nnUNetData_plans_v2.1"
+ self.plans_fname = join(self.preprocessed_output_folder,
+ "nnUNetPlansv2.1_plans_3D.pkl")
+ self.unet_base_num_features = 32
+
+ def get_target_spacing(self):
+ """
+ per default we use the 50th percentile=median for the target spacing. Higher spacing results in smaller data
+ and thus faster and easier training. Smaller spacing results in larger data and thus longer and harder training
+
+ For some datasets the median is not a good choice. Those are the datasets where the spacing is very anisotropic
+ (for example ACDC with (10, 1.5, 1.5)). These datasets still have examples with a spacing of 5 or 6 mm in the low
+ resolution axis. Choosing the median here will result in bad interpolation artifacts that can substantially
+ impact performance (due to the low number of slices).
+ """
+ spacings = self.dataset_properties['all_spacings']
+ sizes = self.dataset_properties['all_sizes']
+
+ target = np.percentile(np.vstack(spacings), self.target_spacing_percentile, 0)
+
+ # This should be used to determine the new median shape. The old implementation is not 100% correct.
+ # Fixed in 2.4
+ # sizes = [np.array(i) / target * np.array(j) for i, j in zip(spacings, sizes)]
+
+ target_size = np.percentile(np.vstack(sizes), self.target_spacing_percentile, 0)
+ target_size_mm = np.array(target) * np.array(target_size)
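+        # note: target_size_mm is currently only used by the commented-out heuristic further down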
+ # we need to identify datasets for which a different target spacing could be beneficial. These datasets have
+ # the following properties:
+ # - one axis which much lower resolution than the others
+ # - the lowres axis has much less voxels than the others
+ # - (the size in mm of the lowres axis is also reduced)
+ worst_spacing_axis = np.argmax(target)
+ other_axes = [i for i in range(len(target)) if i != worst_spacing_axis]
+ other_spacings = [target[i] for i in other_axes]
+ other_sizes = [target_size[i] for i in other_axes]
+
+ has_aniso_spacing = target[worst_spacing_axis] > (self.anisotropy_threshold * max(other_spacings))
+ has_aniso_voxels = target_size[worst_spacing_axis] * self.anisotropy_threshold < min(other_sizes)
+ # we don't use the last one for now
+ #median_size_in_mm = target[target_size_mm] * RESAMPLING_SEPARATE_Z_ANISOTROPY_THRESHOLD < max(target_size_mm)
+
+ if has_aniso_spacing and has_aniso_voxels:
+ spacings_of_that_axis = np.vstack(spacings)[:, worst_spacing_axis]
+ target_spacing_of_that_axis = np.percentile(spacings_of_that_axis, 10)
+ # don't let the spacing of that axis get higher than the other axes
+ if target_spacing_of_that_axis < max(other_spacings):
+ target_spacing_of_that_axis = max(max(other_spacings), target_spacing_of_that_axis) + 1e-5
+ target[worst_spacing_axis] = target_spacing_of_that_axis
+ return target
+
+ def get_properties_for_stage(self, current_spacing, original_spacing, original_shape, num_cases,
+ num_modalities, num_classes):
+ """
+ ExperimentPlanner configures pooling so that we pool late. Meaning that if the number of pooling per axis is
+ (2, 3, 3), then the first pooling operation will always pool axes 1 and 2 and not 0, irrespective of spacing.
+ This can cause a larger memory footprint, so it can be beneficial to revise this.
+
+ Here we are pooling based on the spacing of the data.
+
+ """
+ new_median_shape = np.round(original_spacing / current_spacing * original_shape).astype(int)
+ dataset_num_voxels = np.prod(new_median_shape) * num_cases
+
+        # the next line is what we had before as a default. The patch size had the same aspect ratio as the
+        # median shape of a patient. We swapped this for the spacing-based computation below.
+ # input_patch_size = new_median_shape
+
+ # compute how many voxels are one mm
+ input_patch_size = 1 / np.array(current_spacing)
+
+ # normalize voxels per mm
+ input_patch_size /= input_patch_size.mean()
+
+ # create an isotropic patch of size 512x512x512mm
+ input_patch_size *= 1 / min(input_patch_size) * 512 # to get a starting value
+ input_patch_size = np.round(input_patch_size).astype(int)
+
+        # clip it to the median shape of the dataset because patches larger than that make little sense
+ input_patch_size = [min(i, j) for i, j in zip(input_patch_size, new_median_shape)]
+
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, input_patch_size,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool)
+
+ # we compute as if we were using only 30 feature maps. We can do that because fp16 training is the standard
+        # now. That frees up some space. The decision to go with 32 is solely due to the speedup we get
+        # (non-multiples of 8 do not get the full speedup from nvidia amp)
+ ref = Generic_UNet.use_this_for_batch_size_computation_3D * self.unet_base_num_features / \
+ Generic_UNet.BASE_NUM_FEATURES_3D
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes,
+ pool_op_kernel_sizes, conv_per_stage=self.conv_per_stage)
+ while here > ref:
+ axis_to_be_reduced = np.argsort(new_shp / new_median_shape)[-1]
+
+ tmp = deepcopy(new_shp)
+ tmp[axis_to_be_reduced] -= shape_must_be_divisible_by[axis_to_be_reduced]
+ _, _, _, _, shape_must_be_divisible_by_new = \
+ get_pool_and_conv_props(current_spacing, tmp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ )
+ new_shp[axis_to_be_reduced] -= shape_must_be_divisible_by_new[axis_to_be_reduced]
+
+ # we have to recompute numpool now:
+ network_num_pool_per_axis, pool_op_kernel_sizes, conv_kernel_sizes, new_shp, \
+ shape_must_be_divisible_by = get_pool_and_conv_props(current_spacing, new_shp,
+ self.unet_featuremap_min_edge_length,
+ self.unet_max_numpool,
+ )
+
+ here = Generic_UNet.compute_approx_vram_consumption(new_shp, network_num_pool_per_axis,
+ self.unet_base_num_features,
+ self.unet_max_num_filters, num_modalities,
+ num_classes, pool_op_kernel_sizes,
+ conv_per_stage=self.conv_per_stage)
+ #print(new_shp)
+ #print(here, ref)
+
+ input_patch_size = new_shp
+
+        batch_size = Generic_UNet.DEFAULT_BATCH_SIZE_3D  # This is what works with 128**3
+ batch_size = int(np.floor(max(ref / here, 1) * batch_size))
+
+ # check if batch size is too large
+ max_batch_size = np.round(self.batch_size_covers_max_percent_of_dataset * dataset_num_voxels /
+ np.prod(input_patch_size, dtype=np.int64)).astype(int)
+ max_batch_size = max(max_batch_size, self.unet_min_batch_size)
+ batch_size = max(1, min(batch_size, max_batch_size))
+
+        do_dummy_2D_data_aug = (max(input_patch_size) / input_patch_size[0]) > self.anisotropy_threshold
+
+ plan = {
+ 'batch_size': batch_size,
+ 'num_pool_per_axis': network_num_pool_per_axis,
+ 'patch_size': input_patch_size,
+ 'median_patient_size_in_voxels': new_median_shape,
+ 'current_spacing': current_spacing,
+ 'original_spacing': original_spacing,
+ 'do_dummy_2D_data_aug': do_dummy_2D_data_aug,
+ 'pool_op_kernel_sizes': pool_op_kernel_sizes,
+ 'conv_kernel_sizes': conv_kernel_sizes,
+ }
+ return plan
diff --git a/nnunet/experiment_planning/nnUNet_convert_decathlon_task.py b/nnunet/experiment_planning/nnUNet_convert_decathlon_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf5285a1da802c1980a36f83a3b810f56d63bdfb
--- /dev/null
+++ b/nnunet/experiment_planning/nnUNet_convert_decathlon_task.py
@@ -0,0 +1,64 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.configuration import default_num_threads
+from nnunet.experiment_planning.utils import split_4d
+from nnunet.utilities.file_endings import remove_trailing_slash
+
+
+def crawl_and_remove_hidden_from_decathlon(folder):
+ folder = remove_trailing_slash(folder)
+ assert folder.split('/')[-1].startswith("Task"), "This does not seem to be a decathlon folder. Please give me a " \
+ "folder that starts with TaskXX and has the subfolders imagesTr, " \
+ "labelsTr and imagesTs"
+ subf = subfolders(folder, join=False)
+ assert 'imagesTr' in subf, "This does not seem to be a decathlon folder. Please give me a " \
+ "folder that starts with TaskXX and has the subfolders imagesTr, " \
+ "labelsTr and imagesTs"
+ assert 'imagesTs' in subf, "This does not seem to be a decathlon folder. Please give me a " \
+ "folder that starts with TaskXX and has the subfolders imagesTr, " \
+ "labelsTr and imagesTs"
+ assert 'labelsTr' in subf, "This does not seem to be a decathlon folder. Please give me a " \
+ "folder that starts with TaskXX and has the subfolders imagesTr, " \
+ "labelsTr and imagesTs"
+ _ = [os.remove(i) for i in subfiles(folder, prefix=".")]
+ _ = [os.remove(i) for i in subfiles(join(folder, 'imagesTr'), prefix=".")]
+ _ = [os.remove(i) for i in subfiles(join(folder, 'labelsTr'), prefix=".")]
+ _ = [os.remove(i) for i in subfiles(join(folder, 'imagesTs'), prefix=".")]
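+
+# For reference, the folder layout this check expects (MSD convention; task name is illustrative):
+#   Task04_Hippocampus/
+#       dataset.json
+#       imagesTr/   imagesTs/   labelsTr/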
+
+
+def main():
+ import argparse
+ parser = argparse.ArgumentParser(description="The MSD provides data as 4D Niftis with the modality being the first"
+ " dimension. We think this may be cumbersome for some users and "
+ "therefore expect 3D niftixs instead, with one file per modality. "
+ "This utility will convert 4D MSD data into the format nnU-Net "
+ "expects")
+ parser.add_argument("-i", help="Input folder. Must point to a TaskXX_TASKNAME folder as downloaded from the MSD "
+ "website", required=True)
+ parser.add_argument("-p", required=False, default=default_num_threads, type=int,
+ help="Use this to specify how many processes are used to run the script. "
+ "Default is %d" % default_num_threads)
+ parser.add_argument("-output_task_id", required=False, default=None, type=int,
+ help="If specified, this will overwrite the task id in the output folder. If unspecified, the "
+ "task id of the input folder will be used.")
+ args = parser.parse_args()
+
+ crawl_and_remove_hidden_from_decathlon(args.i)
+
+ split_4d(args.i, args.p, args.output_task_id)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/nnunet/experiment_planning/nnUNet_plan_and_preprocess.py b/nnunet/experiment_planning/nnUNet_plan_and_preprocess.py
new file mode 100644
index 0000000000000000000000000000000000000000..97bdc11aab718971d03ad11ec8f9ae7ccf8a636e
--- /dev/null
+++ b/nnunet/experiment_planning/nnUNet_plan_and_preprocess.py
@@ -0,0 +1,171 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import nnunet
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.experiment_planning.DatasetAnalyzer import DatasetAnalyzer
+from nnunet.experiment_planning.utils import crop
+from nnunet.paths import *
+import shutil
+from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name
+from nnunet.preprocessing.sanity_checks import verify_dataset_integrity
+from nnunet.training.model_restore import recursive_find_python_class
+
+
+def main():
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-t", "--task_ids", nargs="+", help="List of integers belonging to the task ids you wish to run"
+ " experiment planning and preprocessing for. Each of these "
+ "ids must, have a matching folder 'TaskXXX_' in the raw "
+ "data folder")
+ parser.add_argument("-pl3d", "--planner3d", type=str, default="ExperimentPlanner3D_v21",
+ help="Name of the ExperimentPlanner class for the full resolution 3D U-Net and U-Net cascade. "
+ "Default is ExperimentPlanner3D_v21. Can be 'None', in which case these U-Nets will not be "
+ "configured")
+ parser.add_argument("-pl2d", "--planner2d", type=str, default="ExperimentPlanner2D_v21",
+ help="Name of the ExperimentPlanner class for the 2D U-Net. Default is ExperimentPlanner2D_v21. "
+ "Can be 'None', in which case this U-Net will not be configured")
+ parser.add_argument("-no_pp", action="store_true",
+ help="Set this flag if you dont want to run the preprocessing. If this is set then this script "
+ "will only run the experiment planning and create the plans file")
+ parser.add_argument("-tl", type=int, required=False, default=8,
+ help="Number of processes used for preprocessing the low resolution data for the 3D low "
+ "resolution U-Net. This can be larger than -tf. Don't overdo it or you will run out of "
+ "RAM")
+ parser.add_argument("-tf", type=int, required=False, default=8,
+ help="Number of processes used for preprocessing the full resolution data of the 2D U-Net and "
+ "3D U-Net. Don't overdo it or you will run out of RAM")
+ parser.add_argument("--verify_dataset_integrity", required=False, default=False, action="store_true",
+ help="set this flag to check the dataset integrity. This is useful and should be done once for "
+ "each dataset!")
+ parser.add_argument("-overwrite_plans", type=str, default=None, required=False,
+ help="Use this to specify a plans file that should be used instead of whatever nnU-Net would "
+ "configure automatically. This will overwrite everything: intensity normalization, "
+ "network architecture, target spacing etc. Using this is useful for using pretrained "
+ "model weights as this will guarantee that the network architecture on the target "
+ "dataset is the same as on the source dataset and the weights can therefore be transferred.\n"
+ "Pro tip: If you want to pretrain on Hepaticvessel and apply the result to LiTS then use "
+ "the LiTS plans to run the preprocessing of the HepaticVessel task.\n"
+ "Make sure to only use plans files that were "
+ "generated with the same number of modalities as the target dataset (LiTS -> BCV or "
+ "LiTS -> Task008_HepaticVessel is OK. BraTS -> LiTS is not (BraTS has 4 input modalities, "
+ "LiTS has just one)). Also only do things that make sense. This functionality is beta with"
+ "no support given.\n"
+ "Note that this will first print the old plans (which are going to be overwritten) and "
+ "then the new ones (provided that -no_pp was NOT set).")
+ parser.add_argument("-overwrite_plans_identifier", type=str, default=None, required=False,
+ help="If you set overwrite_plans you need to provide a unique identifier so that nnUNet knows "
+ "where to look for the correct plans and data. Assume your identifier is called "
+ "IDENTIFIER, the correct training command would be:\n"
+ "'nnUNet_train CONFIG TRAINER TASKID FOLD -p nnUNetPlans_pretrained_IDENTIFIER "
+ "-pretrained_weights FILENAME'")
+
+ args = parser.parse_args()
+ task_ids = args.task_ids
+ dont_run_preprocessing = args.no_pp
+ tl = args.tl
+ tf = args.tf
+ planner_name3d = args.planner3d
+ planner_name2d = args.planner2d
+
+ if planner_name3d == "None":
+ planner_name3d = None
+ if planner_name2d == "None":
+ planner_name2d = None
+
+ if args.overwrite_plans is not None:
+ if planner_name2d is not None:
+ print("Overwriting plans only works for the 3d planner. I am setting '--planner2d' to None. This will "
+ "skip 2d planning and preprocessing.")
+ assert planner_name3d == 'ExperimentPlanner3D_v21_Pretrained', "When using --overwrite_plans you need to use " \
+ "'-pl3d ExperimentPlanner3D_v21_Pretrained'"
+
+ # we need raw data
+ tasks = []
+ for i in task_ids:
+ i = int(i)
+
+ task_name = convert_id_to_task_name(i)
+
+ if args.verify_dataset_integrity:
+ verify_dataset_integrity(join(nnUNet_raw_data, task_name))
+
+ crop(task_name, False, tf)
+
+ tasks.append(task_name)
+
+ search_in = join(nnunet.__path__[0], "experiment_planning")
+
+ if planner_name3d is not None:
+ planner_3d = recursive_find_python_class([search_in], planner_name3d, current_module="nnunet.experiment_planning")
+ if planner_3d is None:
+ raise RuntimeError("Could not find the Planner class %s. Make sure it is located somewhere in "
+ "nnunet.experiment_planning" % planner_name3d)
+ else:
+ planner_3d = None
+
+ if planner_name2d is not None:
+ planner_2d = recursive_find_python_class([search_in], planner_name2d, current_module="nnunet.experiment_planning")
+ if planner_2d is None:
+ raise RuntimeError("Could not find the Planner class %s. Make sure it is located somewhere in "
+ "nnunet.experiment_planning" % planner_name2d)
+ else:
+ planner_2d = None
+
+ for t in tasks:
+ print("\n\n\n", t)
+ cropped_out_dir = os.path.join(nnUNet_cropped_data, t)
+ preprocessing_output_dir_this_task = os.path.join(preprocessing_output_dir, t)
+ #splitted_4d_output_dir_task = os.path.join(nnUNet_raw_data, t)
+ #lists, modalities = create_lists_from_splitted_dataset(splitted_4d_output_dir_task)
+
+ # we need to figure out if we need the intensity properties. We collect them only if one of the modalities is CT
+ dataset_json = load_json(join(cropped_out_dir, 'dataset.json'))
+ modalities = list(dataset_json["modality"].values())
+ collect_intensityproperties = ("CT" in modalities) or ("ct" in modalities)
+ dataset_analyzer = DatasetAnalyzer(cropped_out_dir, overwrite=False, num_processes=tf) # this class creates the fingerprint
+ _ = dataset_analyzer.analyze_dataset(collect_intensityproperties) # this will write output files that will be used by the ExperimentPlanner
+
+
+ maybe_mkdir_p(preprocessing_output_dir_this_task)
+ shutil.copy(join(cropped_out_dir, "dataset_properties.pkl"), preprocessing_output_dir_this_task)
+ shutil.copy(join(nnUNet_raw_data, t, "dataset.json"), preprocessing_output_dir_this_task)
+
+ threads = (tl, tf)
+
+ print("number of threads: ", threads, "\n")
+
+ if planner_3d is not None:
+ if args.overwrite_plans is not None:
+ assert args.overwrite_plans_identifier is not None, "You need to specify -overwrite_plans_identifier"
+ exp_planner = planner_3d(cropped_out_dir, preprocessing_output_dir_this_task, args.overwrite_plans,
+ args.overwrite_plans_identifier)
+ else:
+ exp_planner = planner_3d(cropped_out_dir, preprocessing_output_dir_this_task)
+ exp_planner.plan_experiment()
+ if not dont_run_preprocessing: # double negative, yooo
+ exp_planner.run_preprocessing(threads)
+ if planner_2d is not None:
+ exp_planner = planner_2d(cropped_out_dir, preprocessing_output_dir_this_task)
+ exp_planner.plan_experiment()
+ if not dont_run_preprocessing: # double negative, yooo
+ exp_planner.run_preprocessing(threads)
+
+
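+# Example invocation (hypothetical task id; assumes the console entry point installed by the
+# package is named after this module):
+#   nnUNet_plan_and_preprocess -t 4 --verify_dataset_integrity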
+if __name__ == "__main__":
+ main()
+
diff --git a/nnunet/experiment_planning/old/__init__.py b/nnunet/experiment_planning/old/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/experiment_planning/old/old_plan_and_preprocess_task.py b/nnunet/experiment_planning/old/old_plan_and_preprocess_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..b26c4ae74100e6206e6b6fbfe7d7454e88c26b96
--- /dev/null
+++ b/nnunet/experiment_planning/old/old_plan_and_preprocess_task.py
@@ -0,0 +1,89 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nnunet.experiment_planning.utils import split_4d, crop, analyze_dataset, plan_and_preprocess
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.paths import nnUNet_raw_data
+
+if __name__ == "__main__":
+ import argparse
+ parser = argparse.ArgumentParser()
+ parser.add_argument('-t', '--task', type=str, help="task name. There must be a matching folder in "
+ "raw_dataset_dir", required=True)
+ parser.add_argument('-pl', '--processes_lowres', type=int, default=8, help='number of processes used for '
+ 'preprocessing 3d_lowres data, image '
+ 'splitting and image cropping '
+ 'Default: 8. The distinction between '
+ 'processes_lowres and processes_fullres '
+ 'is necessary because preprocessing '
+ 'at full resolution needs a lot of '
+ 'RAM', required=False)
+ parser.add_argument('-pf', '--processes_fullres', type=int, default=8, help='number of processes used for '
+ 'preprocessing 2d and 3d_fullres '
+ 'data. Default: 8', required=False)
+ parser.add_argument('-o', '--override', type=int, default=0, help="set this to 1 if you want to override "
+ "cropped data and intensityproperties. Default: 0",
+ required=False)
+ parser.add_argument('-s', '--use_splitted', type=int, default=1, help='1 = use splitted data if already present ('
+ 'skip split_4d). 0 = do splitting again. '
+ 'It is safe to set this to 1 at all times '
+ 'unless the dataset was updated in the '
+ 'meantime. Default: 1', required=False)
+ parser.add_argument('-no_preprocessing', type=int, default=0, help='debug only. If set to 1 this will run only '
+ 'experiment planning and not run the '
+ 'preprocessing')
+
+ args = parser.parse_args()
+ task = args.task
+ processes_lowres = args.processes_lowres
+ processes_fullres = args.processes_fullres
+ override = args.override
+ use_splitted = args.use_splitted
+ no_preprocessing = args.no_preprocessing
+
+ if override == 0:
+ override = False
+ elif override == 1:
+ override = True
+ else:
+ raise ValueError("only 0 or 1 allowed for override")
+
+ if no_preprocessing == 0:
+ no_preprocessing = False
+ elif no_preprocessing == 1:
+ no_preprocessing = True
+ else:
+ raise ValueError("only 0 or 1 allowed for override")
+
+ if use_splitted == 0:
+ use_splitted = False
+ elif use_splitted == 1:
+ use_splitted = True
+ else:
+ raise ValueError("only 0 or 1 allowed for use_splitted")
+
+ if task == "all":
+ all_tasks = subdirs(nnUNet_raw_data, prefix="Task", join=False)
+ for t in all_tasks:
+ crop(t, override=override, num_threads=processes_lowres)
+ analyze_dataset(t, override=override, collect_intensityproperties=True, num_processes=processes_lowres)
+ plan_and_preprocess(t, processes_lowres, processes_fullres, no_preprocessing)
+ else:
+ if not use_splitted or not isdir(join(nnUNet_raw_data, task)):
+ print("splitting task ", task)
+ split_4d(task)
+
+ crop(task, override=override, num_threads=processes_lowres)
+ analyze_dataset(task, override, collect_intensityproperties=True, num_processes=processes_lowres)
+ plan_and_preprocess(task, processes_lowres, processes_fullres, no_preprocessing)
diff --git a/nnunet/experiment_planning/summarize_plans.py b/nnunet/experiment_planning/summarize_plans.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c58c39f42a070f26bb8ee651a6c664c6d73553e
--- /dev/null
+++ b/nnunet/experiment_planning/summarize_plans.py
@@ -0,0 +1,79 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.paths import preprocessing_output_dir
+
+
+# This file is intended to double check nnU-Net's design choices. It is meant for development purposes only
+def summarize_plans(file):
+ plans = load_pickle(file)
+ print("num_classes: ", plans['num_classes'])
+ print("modalities: ", plans['modalities'])
+ print("use_mask_for_norm", plans['use_mask_for_norm'])
+ print("keep_only_largest_region", plans['keep_only_largest_region'])
+ print("min_region_size_per_class", plans['min_region_size_per_class'])
+ print("min_size_per_class", plans['min_size_per_class'])
+ print("normalization_schemes", plans['normalization_schemes'])
+ print("stages...\n")
+
+ for i in range(len(plans['plans_per_stage'])):
+ print("stage: ", i)
+ print(plans['plans_per_stage'][i])
+ print("")
+
+
+def write_plans_to_file(f, plans_file):
+ print(plans_file)
+ a = load_pickle(plans_file)
+ stages = list(a['plans_per_stage'].keys())
+ stages.sort()
+ for stage in stages:
+ st = a['plans_per_stage'][stage]
+ patch_size_in_mm = [i * j for i, j in zip(st['patch_size'], st['current_spacing'])]
+ median_patient_size_in_mm = [i * j for i, j in zip(st['median_patient_size_in_voxels'],
+ st['current_spacing'])]
+ f.write(plans_file.split("/")[-2])
+ f.write(";%s" % plans_file.split("/")[-1])
+ f.write(";%d" % stage)
+ f.write(";%s" % str(st['batch_size']))
+ f.write(";%s" % str(st['num_pool_per_axis']))
+ f.write(";%s" % str(st['patch_size']))
+ f.write(";%s" % str([str("%03.2f" % i) for i in patch_size_in_mm]))
+ f.write(";%s" % str(st['median_patient_size_in_voxels']))
+ f.write(";%s" % str([str("%03.2f" % i) for i in median_patient_size_in_mm]))
+ f.write(";%s" % str([str("%03.2f" % i) for i in st['current_spacing']]))
+ f.write(";%s" % str([str("%03.2f" % i) for i in st['original_spacing']]))
+ f.write(";%s" % str(st['pool_op_kernel_sizes']))
+ f.write(";%s" % str(st['conv_kernel_sizes']))
+ f.write(";%s" % str(a['data_identifier']))
+ f.write("\n")
+
+
+if __name__ == "__main__":
+ base_dir = './'  # preprocessing_output_dir
+ task_dirs = [i for i in subdirs(base_dir, join=False, prefix="Task") if i.find("BrainTumor") == -1 and i.find("MSSeg") == -1]
+ print("found %d tasks" % len(task_dirs))
+
+ with open("2019_02_06_plans_summary.csv", 'w') as f:
+ f.write("task;plans_file;stage;batch_size;num_pool_per_axis;patch_size;patch_size(mm);median_patient_size_in_voxels;median_patient_size_in_mm;current_spacing;original_spacing;pool_op_kernel_sizes;conv_kernel_sizes\n")
+ for t in task_dirs:
+ print(t)
+ tmp = join(base_dir, t)
+ plans_files = [i for i in subfiles(tmp, suffix=".pkl", join=False) if i.find("_plans_") != -1 and i.find("Dgx2") == -1]
+ for p in plans_files:
+ write_plans_to_file(f, join(tmp, p))
+ f.write("\n")
+
+
diff --git a/nnunet/experiment_planning/utils.py b/nnunet/experiment_planning/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c118b18f2f6719231e43a1dea3402e22928b6a9
--- /dev/null
+++ b/nnunet/experiment_planning/utils.py
@@ -0,0 +1,222 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import pickle
+import shutil
+from collections import OrderedDict
+from multiprocessing import Pool
+
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import join, isdir, maybe_mkdir_p, subfiles, subdirs, isfile
+from nnunet.configuration import default_num_threads
+from nnunet.experiment_planning.DatasetAnalyzer import DatasetAnalyzer
+from nnunet.experiment_planning.common_utils import split_4d_nifti
+from nnunet.paths import nnUNet_raw_data, nnUNet_cropped_data, preprocessing_output_dir
+from nnunet.preprocessing.cropping import ImageCropper
+
+
+def split_4d(input_folder, num_processes=default_num_threads, overwrite_task_output_id=None):
+ assert isdir(join(input_folder, "imagesTr")) and isdir(join(input_folder, "labelsTr")) and \
+ isfile(join(input_folder, "dataset.json")), \
+ "The input folder must be a valid Task folder from the Medical Segmentation Decathlon with at least the " \
+ "imagesTr and labelsTr subfolders and the dataset.json file"
+
+ while input_folder.endswith("/"):
+ input_folder = input_folder[:-1]
+
+ full_task_name = input_folder.split("/")[-1]
+
+ assert full_task_name.startswith("Task"), "The input folder must point to a folder that starts with TaskXX_"
+
+ first_underscore = full_task_name.find("_")
+ assert first_underscore == 6, "The input folder must start with TaskXX, where XX is a two-digit id: 00, 01, 02, etc."
+
+ input_task_id = int(full_task_name[4:6])
+ if overwrite_task_output_id is None:
+ overwrite_task_output_id = input_task_id
+
+ task_name = full_task_name[7:]
+
+ output_folder = join(nnUNet_raw_data, "Task%03.0d_" % overwrite_task_output_id + task_name)
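+ # e.g. input "Task04_Hippocampus" becomes "Task004_Hippocampus" (task id zero-padded to three digits)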
+
+ if isdir(output_folder):
+ shutil.rmtree(output_folder)
+
+ files = []
+ output_dirs = []
+
+ maybe_mkdir_p(output_folder)
+ for subdir in ["imagesTr", "imagesTs"]:
+ curr_out_dir = join(output_folder, subdir)
+ if not isdir(curr_out_dir):
+ os.mkdir(curr_out_dir)
+ curr_dir = join(input_folder, subdir)
+ nii_files = [join(curr_dir, i) for i in os.listdir(curr_dir) if i.endswith(".nii.gz")]
+ nii_files.sort()
+ for n in nii_files:
+ files.append(n)
+ output_dirs.append(curr_out_dir)
+
+ shutil.copytree(join(input_folder, "labelsTr"), join(output_folder, "labelsTr"))
+
+ p = Pool(num_processes)
+ p.starmap(split_4d_nifti, zip(files, output_dirs))
+ p.close()
+ p.join()
+ shutil.copy(join(input_folder, "dataset.json"), output_folder)
+
+
+def create_lists_from_splitted_dataset(base_folder_splitted):
+ lists = []
+
+ json_file = join(base_folder_splitted, "dataset.json")
+ with open(json_file) as jsn:
+ d = json.load(jsn)
+ training_files = d['training']
+ num_modalities = len(d['modality'].keys())
+ for tr in training_files:
+ cur_pat = []
+ for mod in range(num_modalities):
+ cur_pat.append(join(base_folder_splitted, "imagesTr", tr['image'].split("/")[-1][:-7] +
+ "_%04.0d.nii.gz" % mod))
+ cur_pat.append(join(base_folder_splitted, "labelsTr", tr['label'].split("/")[-1]))
+ lists.append(cur_pat)
+ return lists, {int(i): d['modality'][str(i)] for i in d['modality'].keys()}
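+
+# Illustrative return value for a hypothetical two-modality case "pat1":
+#   lists[0] == [imagesTr/pat1_0000.nii.gz, imagesTr/pat1_0001.nii.gz, labelsTr/pat1.nii.gz]
+#   modality dict == {0: 'CT', 1: 'MR'}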
+
+
+def create_lists_from_splitted_dataset_folder(folder):
+ """
+ does not rely on dataset.json
+ :param folder:
+ :return:
+ """
+ caseIDs = get_caseIDs_from_splitted_dataset_folder(folder)
+ list_of_lists = []
+ for f in caseIDs:
+ list_of_lists.append(subfiles(folder, prefix=f, suffix=".nii.gz", join=True, sort=True))
+ return list_of_lists
+
+
+def get_caseIDs_from_splitted_dataset_folder(folder):
+ files = subfiles(folder, suffix=".nii.gz", join=False)
+ # all files must be .nii.gz and have 4 digit modality index
+ files = [i[:-12] for i in files]
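+ # e.g. "la_003_0000.nii.gz"[:-12] == "la_003"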
+ # only unique patient ids
+ files = np.unique(files)
+ return files
+
+
+def crop(task_string, override=False, num_threads=default_num_threads):
+ cropped_out_dir = join(nnUNet_cropped_data, task_string)
+ maybe_mkdir_p(cropped_out_dir)
+
+ if override and isdir(cropped_out_dir):
+ shutil.rmtree(cropped_out_dir)
+ maybe_mkdir_p(cropped_out_dir)
+
+ splitted_4d_output_dir_task = join(nnUNet_raw_data, task_string)
+ lists, _ = create_lists_from_splitted_dataset(splitted_4d_output_dir_task)
+
+ imgcrop = ImageCropper(num_threads, cropped_out_dir)
+ imgcrop.run_cropping(lists, overwrite_existing=override)
+ shutil.copy(join(nnUNet_raw_data, task_string, "dataset.json"), cropped_out_dir)
+
+
+def analyze_dataset(task_string, override=False, collect_intensityproperties=True, num_processes=default_num_threads):
+ cropped_out_dir = join(nnUNet_cropped_data, task_string)
+ dataset_analyzer = DatasetAnalyzer(cropped_out_dir, overwrite=override, num_processes=num_processes)
+ _ = dataset_analyzer.analyze_dataset(collect_intensityproperties)
+
+
+def plan_and_preprocess(task_string, processes_lowres=default_num_threads, processes_fullres=3, no_preprocessing=False):
+ from nnunet.experiment_planning.experiment_planner_baseline_2DUNet import ExperimentPlanner2D
+ from nnunet.experiment_planning.experiment_planner_baseline_3DUNet import ExperimentPlanner
+
+ preprocessing_output_dir_this_task_train = join(preprocessing_output_dir, task_string)
+ cropped_out_dir = join(nnUNet_cropped_data, task_string)
+ maybe_mkdir_p(preprocessing_output_dir_this_task_train)
+
+ shutil.copy(join(cropped_out_dir, "dataset_properties.pkl"), preprocessing_output_dir_this_task_train)
+ shutil.copy(join(nnUNet_raw_data, task_string, "dataset.json"), preprocessing_output_dir_this_task_train)
+
+ exp_planner = ExperimentPlanner(cropped_out_dir, preprocessing_output_dir_this_task_train)
+ exp_planner.plan_experiment()
+ if not no_preprocessing:
+ exp_planner.run_preprocessing((processes_lowres, processes_fullres))
+
+ exp_planner = ExperimentPlanner2D(cropped_out_dir, preprocessing_output_dir_this_task_train)
+ exp_planner.plan_experiment()
+ if not no_preprocessing:
+ exp_planner.run_preprocessing(processes_fullres)
+
+ # write which class is in which slice to all training cases (required to speed up 2D Dataloader)
+ # This is done for all data so that if we wanted to use them with 2D we could do so
+
+ if not no_preprocessing:
+ p = Pool(default_num_threads)
+
+ # if there is more than one my_data_identifier (different branches) then this code will run for all of them if
+ # they start with the same string. not problematic, but not pretty
+ stages = [i for i in subdirs(preprocessing_output_dir_this_task_train, join=True, sort=True)
+ if i.split("/")[-1].find("stage") != -1]
+ for s in stages:
+ print(s.split("/")[-1])
+ list_of_npz_files = subfiles(s, True, None, ".npz", True)
+ list_of_pkl_files = [i[:-4]+".pkl" for i in list_of_npz_files]
+ all_classes = []
+ for pk in list_of_pkl_files:
+ with open(pk, 'rb') as f:
+ props = pickle.load(f)
+ all_classes_tmp = np.array(props['classes'])
+ all_classes.append(all_classes_tmp[all_classes_tmp >= 0])
+ p.map(add_classes_in_slice_info, zip(list_of_npz_files, list_of_pkl_files, all_classes))
+ p.close()
+ p.join()
+
+
+def add_classes_in_slice_info(args):
+ """
+ We need this for 2D dataloader with oversampling. As of now it will detect slices that contain specific classes
+ at run time, meaning it needs to iterate over an entire patient just to extract one slice. That is obviously bad,
+ so we are doing this once beforehand and just give the dataloader the info it needs in the patients pkl file.
+
+ """
+ npz_file, pkl_file, all_classes = args
+ seg_map = np.load(npz_file)['data'][-1]
+ with open(pkl_file, 'rb') as f:
+ props = pickle.load(f)
+ #if props.get('classes_in_slice_per_axis') is not None:
+ print(pkl_file)
+ # this will be a dict of dict where the first dict encodes the axis along which a slice is extracted in its keys.
+ # The second dict (value of first dict) will have all classes as key and as values a list of all slice ids that
+ # contain this class
+ classes_in_slice = OrderedDict()
+ for axis in range(3):
+ other_axes = tuple([i for i in range(3) if i != axis])
+ classes_in_slice[axis] = OrderedDict()
+ for c in all_classes:
+ valid_slices = np.where(np.sum(seg_map == c, axis=other_axes) > 0)[0]
+ classes_in_slice[axis][c] = valid_slices
+
+ number_of_voxels_per_class = OrderedDict()
+ for c in all_classes:
+ number_of_voxels_per_class[c] = np.sum(seg_map == c)
+
+ props['classes_in_slice_per_axis'] = classes_in_slice
+ props['number_of_voxels_per_class'] = number_of_voxels_per_class
+
+ with open(pkl_file, 'wb') as f:
+ pickle.dump(props, f)
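+
+# Illustrative lookups on the augmented pkl (class label 2 is hypothetical):
+#   props['classes_in_slice_per_axis'][0][2] -> indices of axis-0 slices that contain class 2
+#   props['number_of_voxels_per_class'][2] -> total number of voxels labeled 2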
diff --git a/nnunet/inference/__init__.py b/nnunet/inference/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b8078b9dddddf22182fec2555d8d118ea72622
--- /dev/null
+++ b/nnunet/inference/__init__.py
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *
\ No newline at end of file
diff --git a/nnunet/inference/change_trainer.py b/nnunet/inference/change_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f319ac103068b51b1e72d6479390439eb5b3564a
--- /dev/null
+++ b/nnunet/inference/change_trainer.py
@@ -0,0 +1,51 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from batchgenerators.utilities.file_and_folder_operations import *
+
+
+def pretend_to_be_nnUNetTrainer(folder, checkpoints=("model_best.model.pkl", "model_final_checkpoint.model.pkl")):
+ pretend_to_be_other_trainer(folder, "nnUNetTrainer", checkpoints)
+
+
+def pretend_to_be_other_trainer(folder, new_trainer_name, checkpoints=("model_best.model.pkl", "model_final_checkpoint.model.pkl")):
+ folds = subdirs(folder, prefix="fold_", join=False)
+
+ if isdir(join(folder, 'all')):
+ folds.append('all')
+
+ for c in checkpoints:
+ for f in folds:
+ checkpoint_file = join(folder, f, c)
+ if isfile(checkpoint_file):
+ a = load_pickle(checkpoint_file)
+ a['name'] = new_trainer_name
+ save_pickle(a, checkpoint_file)
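+
+# Illustrative call (model folder path is hypothetical):
+#   pretend_to_be_other_trainer('/models/Task004/nnUNetTrainerV2__nnUNetPlansv2.1', 'nnUNetTrainer')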
+
+
+def main():
+ import argparse
+ parser = argparse.ArgumentParser(description='Use this script to change the nnunet trainer class of a saved '
+ 'model. Useful for models that were trained with trainers that do '
+ 'not support inference (multi GPU trainers) or for trainer classes '
+ 'whose source code is not available. For this to work the network '
+ 'architecture must be identical between the original trainer '
+ 'class and the trainer class we are changing to. This script is '
+ 'experimental and only to be used by advanced users.')
+ parser.add_argument('-i', help='Folder containing the trained model. This folder is the one containing the '
+ 'fold_X subfolders.')
+ parser.add_argument('-tr', help='Name of the new trainer class')
+ args = parser.parse_args()
+ pretend_to_be_other_trainer(args.i, args.tr)
diff --git a/nnunet/inference/ensemble_predictions.py b/nnunet/inference/ensemble_predictions.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0d39a3164c059c71df7fc089bc85d19919f87bf
--- /dev/null
+++ b/nnunet/inference/ensemble_predictions.py
@@ -0,0 +1,128 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import shutil
+from copy import deepcopy
+
+from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax
+from batchgenerators.utilities.file_and_folder_operations import *
+import numpy as np
+from multiprocessing import Pool
+from nnunet.postprocessing.connected_components import apply_postprocessing_to_folder, load_postprocessing
+
+
+def merge_files(files, properties_files, out_file, override, store_npz):
+ if override or not isfile(out_file):
+ softmax = [np.load(f)['softmax'][None] for f in files]
+ softmax = np.vstack(softmax)
+ softmax = np.mean(softmax, 0)
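+ # each file contributes an array of shape (1, c, x, y, z); vstack gives (n, c, x, y, z) and the mean
+ # over axis 0 averages the softmax of all folders being ensembled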
+ props = [load_pickle(f) for f in properties_files]
+
+ reg_class_orders = [p['regions_class_order'] if 'regions_class_order' in p.keys() else None
+ for p in props]
+
+ if not all([i is None for i in reg_class_orders]):
+ # if reg_class_orders are not None then they must be the same in all pkls
+ tmp = reg_class_orders[0]
+ for r in reg_class_orders[1:]:
+ assert tmp == r, 'If merging files with regions_class_order, the regions_class_orders of all ' \
+ 'files must be the same. regions_class_order: %s, \n files: %s' % \
+ (str(reg_class_orders), str(files))
+ regions_class_order = tmp
+ else:
+ regions_class_order = None
+
+ # Softmax probabilities are already at target spacing so this will not do any resampling (resampling parameters
+ # don't matter here)
+ save_segmentation_nifti_from_softmax(softmax, out_file, props[0], 3, regions_class_order, None, None,
+ force_separate_z=None)
+ if store_npz:
+ np.savez_compressed(out_file[:-7] + ".npz", softmax=softmax)
+ save_pickle(props, out_file[:-7] + ".pkl")
+
+
+def merge(folders, output_folder, threads, override=True, postprocessing_file=None, store_npz=False):
+ maybe_mkdir_p(output_folder)
+
+ if postprocessing_file is not None:
+ output_folder_orig = deepcopy(output_folder)
+ output_folder = join(output_folder, 'not_postprocessed')
+ maybe_mkdir_p(output_folder)
+ else:
+ output_folder_orig = None
+
+ patient_ids = [subfiles(i, suffix=".npz", join=False) for i in folders]
+ patient_ids = [i for j in patient_ids for i in j]
+ patient_ids = [i[:-4] for i in patient_ids]
+ patient_ids = np.unique(patient_ids)
+
+ for f in folders:
+ assert all([isfile(join(f, i + ".npz")) for i in patient_ids]), "Not all patient npz are available in " \
+ "all folders"
+ assert all([isfile(join(f, i + ".pkl")) for i in patient_ids]), "Not all patient pkl are available in " \
+ "all folders"
+
+ files = []
+ property_files = []
+ out_files = []
+ for p in patient_ids:
+ files.append([join(f, p + ".npz") for f in folders])
+ property_files.append([join(f, p + ".pkl") for f in folders])
+ out_files.append(join(output_folder, p + ".nii.gz"))
+
+ p = Pool(threads)
+ p.starmap(merge_files, zip(files, property_files, out_files, [override] * len(out_files), [store_npz] * len(out_files)))
+ p.close()
+ p.join()
+
+ if postprocessing_file is not None:
+ for_which_classes, min_valid_obj_size = load_postprocessing(postprocessing_file)
+ print('Postprocessing...')
+ apply_postprocessing_to_folder(output_folder, output_folder_orig,
+ for_which_classes, min_valid_obj_size, threads)
+ shutil.copy(postprocessing_file, output_folder_orig)
+
+
+def main():
+ import argparse
+ parser = argparse.ArgumentParser(description="This script will merge predictions (that were prdicted with the "
+ "-npz option!). You need to specify a postprocessing file so that "
+ "we know here what postprocessing must be applied. Failing to do so "
+ "will disable postprocessing")
+ parser.add_argument('-f', '--folders', nargs='+', help="list of folders to merge. All folders must contain npz "
+ "files", required=True)
+ parser.add_argument('-o', '--output_folder', help="where to save the results", required=True, type=str)
+ parser.add_argument('-t', '--threads', help="number of threads used for saving niftis", required=False, default=2,
+ type=int)
+ parser.add_argument('-pp', '--postprocessing_file', help="path to the file where the postprocessing configuration "
+ "is stored. If this is not provided then no postprocessing "
+ "will be made. It is strongly recommended to provide the "
+ "postprocessing file!",
+ required=False, type=str, default=None)
+ parser.add_argument('--npz', action="store_true", required=False, help="stores npz and pkl")
+
+ args = parser.parse_args()
+
+ folders = args.folders
+ threads = args.threads
+ output_folder = args.output_folder
+ pp_file = args.postprocessing_file
+ npz = args.npz
+
+ merge(folders, output_folder, threads, override=True, postprocessing_file=pp_file, store_npz=npz)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/nnunet/inference/predict.py b/nnunet/inference/predict.py
new file mode 100644
index 0000000000000000000000000000000000000000..e02b5f37b46956c830b5a39ca25bcbda4eedaac9
--- /dev/null
+++ b/nnunet/inference/predict.py
@@ -0,0 +1,837 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+from copy import deepcopy
+from typing import Tuple, Union, List
+
+import numpy as np
+from batchgenerators.augmentations.utils import resize_segmentation
+from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax, save_segmentation_nifti
+from batchgenerators.utilities.file_and_folder_operations import *
+from multiprocessing import Process, Queue
+import torch
+import SimpleITK as sitk
+import shutil
+from multiprocessing import Pool
+from nnunet.postprocessing.connected_components import load_remove_save, load_postprocessing
+from nnunet.training.model_restore import load_model_and_checkpoint_files
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+from nnunet.utilities.one_hot_encoding import to_one_hot
+
+
+def preprocess_save_to_queue(preprocess_fn, q, list_of_lists, output_files, segs_from_prev_stage, classes,
+ transpose_forward):
+ # suppress output
+ # sys.stdout = open(os.devnull, 'w')
+
+ errors_in = []
+ for i, l in enumerate(list_of_lists):
+ try:
+ output_file = output_files[i]
+ print("preprocessing", output_file)
+ d, _, dct = preprocess_fn(l)
+ # print(output_file, dct)
+ if segs_from_prev_stage[i] is not None:
+ assert isfile(segs_from_prev_stage[i]) and segs_from_prev_stage[i].endswith(
+ ".nii.gz"), "segs_from_prev_stage" \
+ " must point to a " \
+ "segmentation file"
+ seg_prev = sitk.GetArrayFromImage(sitk.ReadImage(segs_from_prev_stage[i]))
+ # check to see if shapes match
+ img = sitk.GetArrayFromImage(sitk.ReadImage(l[0]))
+ assert all([i == j for i, j in zip(seg_prev.shape, img.shape)]), "image and segmentation from previous " \
+ "stage don't have the same pixel array " \
+ "shape! image: %s, seg_prev: %s" % \
+ (l[0], segs_from_prev_stage[i])
+ seg_prev = seg_prev.transpose(transpose_forward)
+ seg_reshaped = resize_segmentation(seg_prev, d.shape[1:], order=1)
+ seg_reshaped = to_one_hot(seg_reshaped, classes)
+ d = np.vstack((d, seg_reshaped)).astype(np.float32)
+ """There is a problem with python process communication that prevents us from communicating objects
+ larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
+ communicated by the multiprocessing.Pipe object then the placeholder (I think) does not allow for long
+ enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
+ patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
+ then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
+ filename or np.ndarray and will handle this automatically"""
+ print(d.shape)
+ if np.prod(d.shape) > (2e9 / 4 * 0.85): # *0.85 just to be safe, 4 because float32 is 4 bytes
+ print(
+ "This output is too large for python process-process communication. "
+ "Saving output temporarily to disk")
+ np.save(output_file[:-7] + ".npy", d)
+ d = output_file[:-7] + ".npy"
+ q.put((output_file, (d, dct)))
+ except KeyboardInterrupt:
+ raise KeyboardInterrupt
+ except Exception as e:
+ errors_in.append(l)
+ print("error in", l)
+ print(e)
+ q.put("end")
+ if len(errors_in) > 0:
+ print("There were some errors in the following cases:", errors_in)
+ print("These cases were ignored.")
+ else:
+ print("This worker has ended successfully, no errors to report")
+ # restore output
+ # sys.stdout = sys.__stdout__
+
+
+def preprocess_multithreaded(trainer, list_of_lists, output_files, num_processes=2, segs_from_prev_stage=None):
+ if segs_from_prev_stage is None:
+ segs_from_prev_stage = [None] * len(list_of_lists)
+
+ num_processes = min(len(list_of_lists), num_processes)
+
+ classes = list(range(1, trainer.num_classes))
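+ # by nnU-Net's convention class 0 is background, so only the foreground classes of the previous
+ # stage are one-hot encoded and appended to the data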
+ assert isinstance(trainer, nnUNetTrainer)
+ q = Queue(1)
+ processes = []
+ for i in range(num_processes):
+ pr = Process(target=preprocess_save_to_queue, args=(trainer.preprocess_patient, q,
+ list_of_lists[i::num_processes],
+ output_files[i::num_processes],
+ segs_from_prev_stage[i::num_processes],
+ classes, trainer.plans['transpose_forward']))
+ pr.start()
+ processes.append(pr)
+
+ try:
+ end_ctr = 0
+ while end_ctr != num_processes:
+ item = q.get()
+ if item == "end":
+ end_ctr += 1
+ continue
+ else:
+ yield item
+
+ finally:
+ for p in processes:
+ if p.is_alive():
+ p.terminate() # this should not happen but better safe than sorry right
+ p.join()
+
+ q.close()
+
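+
+# Illustrative consumption of this generator (as done in predict_cases below): it yields
+# (output_filename, (data_or_npy_path, properties_dict)) tuples as soon as a worker finishes,
+# so GPU inference can overlap with CPU preprocessing.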
+
+def predict_cases(model, list_of_lists, output_filenames, folds, save_npz, num_threads_preprocessing,
+ num_threads_nifti_save, segs_from_prev_stage=None, do_tta=True, mixed_precision=True,
+ overwrite_existing=False,
+ all_in_gpu=False, step_size=0.5, checkpoint_name="model_final_checkpoint",
+ segmentation_export_kwargs: dict = None, disable_postprocessing: bool = False):
+ """
+ :param segmentation_export_kwargs:
+ :param model: folder where the model is saved, must contain fold_x subfolders
+ :param list_of_lists: [[case0_0000.nii.gz, case0_0001.nii.gz], [case1_0000.nii.gz, case1_0001.nii.gz], ...]
+ :param output_filenames: [output_file_case0.nii.gz, output_file_case1.nii.gz, ...]
+ :param folds: default: (0, 1, 2, 3, 4) (but can also be 'all' or a subset of the five folds, for example use (0, )
+ for using only fold_0
+ :param save_npz: default: False
+ :param num_threads_preprocessing:
+ :param num_threads_nifti_save:
+ :param segs_from_prev_stage:
+ :param do_tta: default: True, can be set to False for an 8x speedup at the cost of reduced segmentation quality
+ :param overwrite_existing: default: False
+ :param mixed_precision: if None then we take no action. If True/False we overwrite what the model has in its init
+ :return:
+ """
+ assert len(list_of_lists) == len(output_filenames)
+ if segs_from_prev_stage is not None: assert len(segs_from_prev_stage) == len(output_filenames)
+
+ pool = Pool(num_threads_nifti_save)
+ results = []
+
+ cleaned_output_files = []
+ for o in output_filenames:
+ dr, f = os.path.split(o)
+ if len(dr) > 0:
+ maybe_mkdir_p(dr)
+ if not f.endswith(".nii.gz"):
+ f, _ = os.path.splitext(f)
+ f = f + ".nii.gz"
+ cleaned_output_files.append(join(dr, f))
+
+ if not overwrite_existing:
+ print("number of cases:", len(list_of_lists))
+ # if save_npz=True then we should also check for missing npz files
+ not_done_idx = [i for i, j in enumerate(cleaned_output_files) if (not isfile(j)) or (save_npz and not isfile(j[:-7] + '.npz'))]
+
+ cleaned_output_files = [cleaned_output_files[i] for i in not_done_idx]
+ list_of_lists = [list_of_lists[i] for i in not_done_idx]
+ if segs_from_prev_stage is not None:
+ segs_from_prev_stage = [segs_from_prev_stage[i] for i in not_done_idx]
+
+ print("number of cases that still need to be predicted:", len(cleaned_output_files))
+
+ print("emptying cuda cache")
+ torch.cuda.empty_cache()
+
+ print("loading parameters for folds,", folds)
+ trainer, params = load_model_and_checkpoint_files(model, folds, mixed_precision=mixed_precision,
+ checkpoint_name=checkpoint_name)
+
+ if segmentation_export_kwargs is None:
+ if 'segmentation_export_params' in trainer.plans.keys():
+ force_separate_z = trainer.plans['segmentation_export_params']['force_separate_z']
+ interpolation_order = trainer.plans['segmentation_export_params']['interpolation_order']
+ interpolation_order_z = trainer.plans['segmentation_export_params']['interpolation_order_z']
+ else:
+ force_separate_z = None
+ interpolation_order = 1
+ interpolation_order_z = 0
+ else:
+ force_separate_z = segmentation_export_kwargs['force_separate_z']
+ interpolation_order = segmentation_export_kwargs['interpolation_order']
+ interpolation_order_z = segmentation_export_kwargs['interpolation_order_z']
+
+ print("starting preprocessing generator")
+ preprocessing = preprocess_multithreaded(trainer, list_of_lists, cleaned_output_files, num_threads_preprocessing,
+ segs_from_prev_stage)
+ print("starting prediction...")
+ all_output_files = []
+ for preprocessed in preprocessing:
+ output_filename, (d, dct) = preprocessed
+ all_output_files.append(output_filename)
+ if isinstance(d, str):
+ data = np.load(d)
+ os.remove(d)
+ d = data
+
+ print("predicting", output_filename)
+ trainer.load_checkpoint_ram(params[0], False)
+ softmax = trainer.predict_preprocessed_data_return_seg_and_softmax(
+ d, do_mirroring=do_tta, mirror_axes=trainer.data_aug_params['mirror_axes'], use_sliding_window=True,
+ step_size=step_size, use_gaussian=True, all_in_gpu=all_in_gpu,
+ mixed_precision=mixed_precision)[1]
+
+ for p in params[1:]:
+ trainer.load_checkpoint_ram(p, False)
+ softmax += trainer.predict_preprocessed_data_return_seg_and_softmax(
+ d, do_mirroring=do_tta, mirror_axes=trainer.data_aug_params['mirror_axes'], use_sliding_window=True,
+ step_size=step_size, use_gaussian=True, all_in_gpu=all_in_gpu,
+ mixed_precision=mixed_precision)[1]
+
+ if len(params) > 1:
+ softmax /= len(params)
+
+ transpose_forward = trainer.plans.get('transpose_forward')
+ if transpose_forward is not None:
+ transpose_backward = trainer.plans.get('transpose_backward')
+ softmax = softmax.transpose([0] + [i + 1 for i in transpose_backward])
+
+ if save_npz:
+ npz_file = output_filename[:-7] + ".npz"
+ else:
+ npz_file = None
+
+ if hasattr(trainer, 'regions_class_order'):
+ region_class_order = trainer.regions_class_order
+ else:
+ region_class_order = None
+
+ """There is a problem with python process communication that prevents us from communicating objects
+ larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
+ communicated by the multiprocessing.Pipe object then the placeholder (I think) does not allow for long
+ enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
+ patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
+ then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
+ filename or np.ndarray and will handle this automatically"""
+ bytes_per_voxel = 4
+ if all_in_gpu:
+ bytes_per_voxel = 2 # if all_in_gpu then the return value is half (float16)
+ if np.prod(softmax.shape) > (2e9 / bytes_per_voxel * 0.85): # * 0.85 just to be safe
+ print(
+ "This output is too large for python process-process communication. Saving output temporarily to disk")
+ np.save(output_filename[:-7] + ".npy", softmax)
+ softmax = output_filename[:-7] + ".npy"
+
+ results.append(pool.starmap_async(save_segmentation_nifti_from_softmax,
+ ((softmax, output_filename, dct, interpolation_order, region_class_order,
+ None, None,
+ npz_file, None, force_separate_z, interpolation_order_z),)
+ ))
+
+ print("inference done. Now waiting for the segmentation export to finish...")
+ _ = [i.get() for i in results]
+ # now apply postprocessing
+ # first load the postprocessing properties if they are present. Else raise a well visible warning
+ if not disable_postprocessing:
+ results = []
+ pp_file = join(model, "postprocessing.json")
+ if isfile(pp_file):
+ print("postprocessing...")
+ shutil.copy(pp_file, os.path.abspath(os.path.dirname(output_filenames[0])))
+ # for_which_classes stores for which of the classes everything but the largest connected component needs to be
+ # removed
+ for_which_classes, min_valid_obj_size = load_postprocessing(pp_file)
+ results.append(pool.starmap_async(load_remove_save,
+ zip(output_filenames, output_filenames,
+ [for_which_classes] * len(output_filenames),
+ [min_valid_obj_size] * len(output_filenames))))
+ _ = [i.get() for i in results]
+ else:
+ print("WARNING! Cannot run postprocessing because the postprocessing file is missing. Make sure to run "
+ "consolidate_folds in the output folder of the model first!\nThe folder you need to run this in is "
+ "%s" % model)
+
+ pool.close()
+ pool.join()
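+
+# Minimal usage sketch (paths are hypothetical; see the docstring above for parameter semantics):
+#   predict_cases(model='RESULTS/3d_fullres/Task004_Hippocampus/nnUNetTrainerV2__nnUNetPlansv2.1',
+#                 list_of_lists=[['case0_0000.nii.gz']], output_filenames=['case0.nii.gz'],
+#                 folds=(0, 1, 2, 3, 4), save_npz=False, num_threads_preprocessing=6,
+#                 num_threads_nifti_save=2)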
+
+
+def predict_cases_fast(model, list_of_lists, output_filenames, folds, num_threads_preprocessing,
+ num_threads_nifti_save, segs_from_prev_stage=None, do_tta=True, mixed_precision=True,
+ overwrite_existing=False,
+ all_in_gpu=False, step_size=0.5, checkpoint_name="model_final_checkpoint",
+ segmentation_export_kwargs: dict = None, disable_postprocessing: bool = False):
+ assert len(list_of_lists) == len(output_filenames)
+ if segs_from_prev_stage is not None: assert len(segs_from_prev_stage) == len(output_filenames)
+
+ pool = Pool(num_threads_nifti_save)
+ results = []
+
+ cleaned_output_files = []
+ for o in output_filenames:
+ dr, f = os.path.split(o)
+ if len(dr) > 0:
+ maybe_mkdir_p(dr)
+ if not f.endswith(".nii.gz"):
+ f, _ = os.path.splitext(f)
+ f = f + ".nii.gz"
+ cleaned_output_files.append(join(dr, f))
+
+ if not overwrite_existing:
+ print("number of cases:", len(list_of_lists))
+ not_done_idx = [i for i, j in enumerate(cleaned_output_files) if not isfile(j)]
+
+ cleaned_output_files = [cleaned_output_files[i] for i in not_done_idx]
+ list_of_lists = [list_of_lists[i] for i in not_done_idx]
+ if segs_from_prev_stage is not None:
+ segs_from_prev_stage = [segs_from_prev_stage[i] for i in not_done_idx]
+
+ print("number of cases that still need to be predicted:", len(cleaned_output_files))
+
+ print("emptying cuda cache")
+ torch.cuda.empty_cache()
+
+ print("loading parameters for folds,", folds)
+ trainer, params = load_model_and_checkpoint_files(model, folds, mixed_precision=mixed_precision,
+ checkpoint_name=checkpoint_name)
+
+ if segmentation_export_kwargs is None:
+ if 'segmentation_export_params' in trainer.plans.keys():
+ force_separate_z = trainer.plans['segmentation_export_params']['force_separate_z']
+ interpolation_order = trainer.plans['segmentation_export_params']['interpolation_order']
+ interpolation_order_z = trainer.plans['segmentation_export_params']['interpolation_order_z']
+ else:
+ force_separate_z = None
+ interpolation_order = 1
+ interpolation_order_z = 0
+ else:
+ force_separate_z = segmentation_export_kwargs['force_separate_z']
+ interpolation_order = segmentation_export_kwargs['interpolation_order']
+ interpolation_order_z = segmentation_export_kwargs['interpolation_order_z']
+
+ print("starting preprocessing generator")
+ preprocessing = preprocess_multithreaded(trainer, list_of_lists, cleaned_output_files, num_threads_preprocessing,
+ segs_from_prev_stage)
+
+ print("starting prediction...")
+ for preprocessed in preprocessing:
+ print("getting data from preprocessor")
+ output_filename, (d, dct) = preprocessed
+ print("got something")
+ if isinstance(d, str):
+ print("what I got is a string, so I need to load a file")
+ data = np.load(d)
+ os.remove(d)
+ d = data
+
+ # preallocate the output arrays
+ # same dtype as the return value in predict_preprocessed_data_return_seg_and_softmax (saves time)
+ softmax_aggr = None # np.zeros((trainer.num_classes, *d.shape[1:]), dtype=np.float16)
+ all_seg_outputs = np.zeros((len(params), *d.shape[1:]), dtype=int)
+ print("predicting", output_filename)
+
+ for i, p in enumerate(params):
+ trainer.load_checkpoint_ram(p, False)
+
+ res = trainer.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=do_tta,
+ mirror_axes=trainer.data_aug_params['mirror_axes'],
+ use_sliding_window=True,
+ step_size=step_size, use_gaussian=True,
+ all_in_gpu=all_in_gpu,
+ mixed_precision=mixed_precision)
+
+ if len(params) > 1:
+ # otherwise we don't need this and we can save ourselves the time it takes to copy that
+ print("aggregating softmax")
+ if softmax_aggr is None:
+ softmax_aggr = res[1]
+ else:
+ softmax_aggr += res[1]
+ all_seg_outputs[i] = res[0]
+
+ print("obtaining segmentation map")
+ if len(params) > 1:
+ # we don't need to normalize the softmax by 1 / len(params) because this would not change the outcome of the argmax
+ seg = softmax_aggr.argmax(0)
+ else:
+ seg = all_seg_outputs[0]
+
+ print("applying transpose_backward")
+ transpose_forward = trainer.plans.get('transpose_forward')
+ if transpose_forward is not None:
+ transpose_backward = trainer.plans.get('transpose_backward')
+ seg = seg.transpose([i for i in transpose_backward])
+
+ if hasattr(trainer, 'regions_class_order'):
+ region_class_order = trainer.regions_class_order
+ else:
+ region_class_order = None
+ assert region_class_order is None, "predict_cases_fast can only work with regular softmax predictions " \
+ "and is therefore unable to handle trainer classes with region_class_order"
+
+ print("initializing segmentation export")
+ results.append(pool.starmap_async(save_segmentation_nifti,
+ ((seg, output_filename, dct, interpolation_order, force_separate_z,
+ interpolation_order_z),)
+ ))
+ print("done")
+
+ print("inference done. Now waiting for the segmentation export to finish...")
+ _ = [i.get() for i in results]
+ # now apply postprocessing
+ # first load the postprocessing properties if they are present. Else raise a well visible warning
+
+ if not disable_postprocessing:
+ results = []
+ pp_file = join(model, "postprocessing.json")
+ if isfile(pp_file):
+ print("postprocessing...")
+ shutil.copy(pp_file, os.path.dirname(output_filenames[0]))
+ # for_which_classes stores for which of the classes everything but the largest connected component needs to be
+ # removed
+ for_which_classes, min_valid_obj_size = load_postprocessing(pp_file)
+ results.append(pool.starmap_async(load_remove_save,
+ zip(output_filenames, output_filenames,
+ [for_which_classes] * len(output_filenames),
+ [min_valid_obj_size] * len(output_filenames))))
+ _ = [i.get() for i in results]
+ else:
+ print("WARNING! Cannot run postprocessing because the postprocessing file is missing. Make sure to run "
+ "consolidate_folds in the output folder of the model first!\nThe folder you need to run this in is "
+ "%s" % model)
+
+ pool.close()
+ pool.join()
+
+
+def predict_cases_fastest(model, list_of_lists, output_filenames, folds, num_threads_preprocessing,
+ num_threads_nifti_save, segs_from_prev_stage=None, do_tta=True, mixed_precision=True,
+ overwrite_existing=False, all_in_gpu=False, step_size=0.5,
+ checkpoint_name="model_final_checkpoint", disable_postprocessing: bool = False):
+ assert len(list_of_lists) == len(output_filenames)
+ if segs_from_prev_stage is not None:
+ assert len(segs_from_prev_stage) == len(output_filenames)
+
+ pool = Pool(num_threads_nifti_save)
+ results = []
+
+ cleaned_output_files = []
+ for o in output_filenames:
+ dr, f = os.path.split(o)
+ if len(dr) > 0:
+ maybe_mkdir_p(dr)
+ if not f.endswith(".nii.gz"):
+ f, _ = os.path.splitext(f)
+ f = f + ".nii.gz"
+ cleaned_output_files.append(join(dr, f))
+
+ if not overwrite_existing:
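+ # resume mode: keep only the cases whose output file does not exist yet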
+ print("number of cases:", len(list_of_lists))
+ not_done_idx = [i for i, j in enumerate(cleaned_output_files) if not isfile(j)]
+
+ cleaned_output_files = [cleaned_output_files[i] for i in not_done_idx]
+ list_of_lists = [list_of_lists[i] for i in not_done_idx]
+ if segs_from_prev_stage is not None:
+ segs_from_prev_stage = [segs_from_prev_stage[i] for i in not_done_idx]
+
+ print("number of cases that still need to be predicted:", len(cleaned_output_files))
+
+ print("emptying cuda cache")
+ torch.cuda.empty_cache()
+
+ print("loading parameters for folds,", folds)
+ trainer, params = load_model_and_checkpoint_files(model, folds, mixed_precision=mixed_precision,
+ checkpoint_name=checkpoint_name)
+
+ print("starting preprocessing generator")
+ preprocessing = preprocess_multithreaded(trainer, list_of_lists, cleaned_output_files, num_threads_preprocessing,
+ segs_from_prev_stage)
+
+ print("starting prediction...")
+ for preprocessed in preprocessing:
+ print("getting data from preprocessor")
+ output_filename, (d, dct) = preprocessed
+ print("got something")
+ if isinstance(d, str):
+ print("what I got is a string, so I need to load a file")
+ data = np.load(d)
+ os.remove(d)
+ d = data
+
+ # preallocate the output arrays
+ # same dtype as the return value in predict_preprocessed_data_return_seg_and_softmax (saves time)
+ all_softmax_outputs = np.zeros((len(params), trainer.num_classes, *d.shape[1:]), dtype=np.float16)
+ all_seg_outputs = np.zeros((len(params), *d.shape[1:]), dtype=int)
+ print("predicting", output_filename)
+
+ for i, p in enumerate(params):
+ trainer.load_checkpoint_ram(p, False)
+ res = trainer.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=do_tta,
+ mirror_axes=trainer.data_aug_params['mirror_axes'],
+ use_sliding_window=True,
+ step_size=step_size, use_gaussian=True,
+ all_in_gpu=all_in_gpu,
+ mixed_precision=mixed_precision)
+ if len(params) > 1:
+ # otherwise we don't need this and we can save ourselves the time it takes to copy it
+ all_softmax_outputs[i] = res[1]
+ all_seg_outputs[i] = res[0]
+
+ if hasattr(trainer, 'regions_class_order'):
+ region_class_order = trainer.regions_class_order
+ else:
+ region_class_order = None
+ assert region_class_order is None, "predict_cases_fastest can only work with regular softmax predictions " \
+ "and is therefore unable to handle trainer classes with region_class_order"
+
+ print("aggregating predictions")
+ if len(params) > 1:
+ softmax_mean = np.mean(all_softmax_outputs, 0)
+ seg = softmax_mean.argmax(0)
+ else:
+ seg = all_seg_outputs[0]
+
+ print("applying transpose_backward")
+ transpose_forward = trainer.plans.get('transpose_forward')
+ if transpose_forward is not None:
+ transpose_backward = trainer.plans.get('transpose_backward')
+ seg = seg.transpose(transpose_backward)
+
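+ # export with interpolation order 0 (nearest neighbor) and force_separate_z=None; this trades
+ # resampling quality for speed compared to the order-1 export used by predict_cases_fast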
+ print("initializing segmentation export")
+ results.append(pool.starmap_async(save_segmentation_nifti,
+ ((seg, output_filename, dct, 0, None),)
+ ))
+ print("done")
+
+ print("inference done. Now waiting for the segmentation export to finish...")
+ _ = [i.get() for i in results]
+ # now apply postprocessing
+ # first load the postprocessing properties if they are present. Else print a clearly visible warning
+ if not disable_postprocessing:
+ results = []
+ pp_file = join(model, "postprocessing.json")
+ if isfile(pp_file):
+ print("postprocessing...")
+ shutil.copy(pp_file, os.path.dirname(output_filenames[0]))
+ # for_which_classes stores the classes for which everything but the largest connected component is removed
+ for_which_classes, min_valid_obj_size = load_postprocessing(pp_file)
+ results.append(pool.starmap_async(load_remove_save,
+ zip(output_filenames, output_filenames,
+ [for_which_classes] * len(output_filenames),
+ [min_valid_obj_size] * len(output_filenames))))
+ _ = [i.get() for i in results]
+ else:
+ print("WARNING! Cannot run postprocessing because the postprocessing file is missing. Make sure to run "
+ "consolidate_folds in the output folder of the model first!\nThe folder you need to run this in is "
+ "%s" % model)
+
+ pool.close()
+ pool.join()
+
+
+def check_input_folder_and_return_caseIDs(input_folder, expected_num_modalities):
+ print("This model expects %d input modalities for each image" % expected_num_modalities)
+ files = subfiles(input_folder, suffix=".nii.gz", join=False, sort=True)
+
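+ # filenames end with the modality suffix _XXXX.nii.gz (12 characters), so stripping the last
+ # 12 characters of each filename yields the case identifier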
+ maybe_case_ids = np.unique([i[:-12] for i in files])
+
+ remaining = deepcopy(files)
+ missing = []
+
+ assert len(files) > 0, "input folder did not contain any images (expected to find .nii.gz file endings)"
+
+ # now check if all required files are present and that no unexpected files are remaining
+ for c in maybe_case_ids:
+ for n in range(expected_num_modalities):
+ expected_output_file = c + "_%04d.nii.gz" % n
+ if not isfile(join(input_folder, expected_output_file)):
+ missing.append(expected_output_file)
+ else:
+ remaining.remove(expected_output_file)
+
+ print("Found %d unique case ids, here are some examples:" % len(maybe_case_ids),
+ np.random.choice(maybe_case_ids, min(len(maybe_case_ids), 10)))
+ print("If they don't look right, make sure to double check your filenames. They must end with _0000.nii.gz etc")
+
+ if len(remaining) > 0:
+ print("found %d unexpected remaining files in the folder. Here are some examples:" % len(remaining),
+ np.random.choice(remaining, min(len(remaining), 10)))
+
+ if len(missing) > 0:
+ print("Some files are missing:")
+ print(missing)
+ raise RuntimeError("missing files in input_folder")
+
+ return maybe_case_ids
+
+
+def predict_from_folder(model: str, input_folder: str, output_folder: str, folds: Union[Tuple[int], List[int]],
+ save_npz: bool, num_threads_preprocessing: int, num_threads_nifti_save: int,
+ lowres_segmentations: Union[str, None],
+ part_id: int, num_parts: int, tta: bool, mixed_precision: bool = True,
+ overwrite_existing: bool = True, mode: str = 'normal', overwrite_all_in_gpu: bool = None,
+ step_size: float = 0.5, checkpoint_name: str = "model_final_checkpoint",
+ segmentation_export_kwargs: dict = None, disable_postprocessing: bool = False):
+ """
+ here we use the standard naming scheme to generate list_of_lists and output_files needed by predict_cases
+
+ :param model:
+ :param input_folder:
+ :param output_folder:
+ :param folds:
+ :param save_npz:
+ :param num_threads_preprocessing:
+ :param num_threads_nifti_save:
+ :param lowres_segmentations:
+ :param part_id:
+ :param num_parts:
+ :param tta:
+ :param mixed_precision:
+ :param overwrite_existing: if True, cases whose prediction already exists in output_folder are predicted again and overwritten. Default: True
+ :return:
+ """
+ maybe_mkdir_p(output_folder)
+ # check that the plans file exists before copying and reading it
+ assert isfile(join(model, "plans.pkl")), "Folder with saved model weights must contain a plans.pkl file"
+ shutil.copy(join(model, 'plans.pkl'), output_folder)
+
+ expected_num_modalities = load_pickle(join(model, "plans.pkl"))['num_modalities']
+
+ # check input folder integrity
+ case_ids = check_input_folder_and_return_caseIDs(input_folder, expected_num_modalities)
+
+ output_files = [join(output_folder, i + ".nii.gz") for i in case_ids]
+ all_files = subfiles(input_folder, suffix=".nii.gz", join=False, sort=True)
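+ # for every case id j, collect all files that start with j and carry exactly the 12 extra
+ # characters of the modality suffix (_XXXX.nii.gz)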
+ list_of_lists = [[join(input_folder, i) for i in all_files if i.startswith(j) and
+ len(i) == (len(j) + 12)] for j in case_ids]
+
+ if lowres_segmentations is not None:
+ assert isdir(lowres_segmentations), "if lowres_segmentations is not None then it must point to a directory"
+ lowres_segmentations = [join(lowres_segmentations, i + ".nii.gz") for i in case_ids]
+ assert all([isfile(i) for i in lowres_segmentations]), "not all lowres_segmentations files are present. " \
+ "(I was searching for case_id.nii.gz in that folder)"
+ lowres_segmentations = lowres_segmentations[part_id::num_parts]
+ else:
+ lowres_segmentations = None
+
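+ # each of the num_parts workers processes every num_parts-th case (strided split via
+ # list[part_id::num_parts]); this is how prediction is parallelized across GPUs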
+ if mode == "normal":
+ if overwrite_all_in_gpu is None:
+ all_in_gpu = False
+ else:
+ all_in_gpu = overwrite_all_in_gpu
+
+ return predict_cases(model, list_of_lists[part_id::num_parts], output_files[part_id::num_parts], folds,
+ save_npz, num_threads_preprocessing, num_threads_nifti_save, lowres_segmentations, tta,
+ mixed_precision=mixed_precision, overwrite_existing=overwrite_existing,
+ all_in_gpu=all_in_gpu,
+ step_size=step_size, checkpoint_name=checkpoint_name,
+ segmentation_export_kwargs=segmentation_export_kwargs,
+ disable_postprocessing=disable_postprocessing)
+ elif mode == "fast":
+ if overwrite_all_in_gpu is None:
+ all_in_gpu = False
+ else:
+ all_in_gpu = overwrite_all_in_gpu
+
+ assert save_npz is False, "mode 'fast' does not support saving softmax (save_npz)"
+ return predict_cases_fast(model, list_of_lists[part_id::num_parts], output_files[part_id::num_parts], folds,
+ num_threads_preprocessing, num_threads_nifti_save, lowres_segmentations,
+ tta, mixed_precision=mixed_precision, overwrite_existing=overwrite_existing,
+ all_in_gpu=all_in_gpu,
+ step_size=step_size, checkpoint_name=checkpoint_name,
+ segmentation_export_kwargs=segmentation_export_kwargs,
+ disable_postprocessing=disable_postprocessing)
+ elif mode == "fastest":
+ if overwrite_all_in_gpu is None:
+ all_in_gpu = False
+ else:
+ all_in_gpu = overwrite_all_in_gpu
+
+ assert save_npz is False, "mode 'fastest' does not support saving softmax (save_npz)"
+ return predict_cases_fastest(model, list_of_lists[part_id::num_parts], output_files[part_id::num_parts], folds,
+ num_threads_preprocessing, num_threads_nifti_save, lowres_segmentations,
+ tta, mixed_precision=mixed_precision, overwrite_existing=overwrite_existing,
+ all_in_gpu=all_in_gpu,
+ step_size=step_size, checkpoint_name=checkpoint_name,
+ disable_postprocessing=disable_postprocessing)
+ else:
+ raise ValueError("unrecognized mode. Must be normal, fast or fastest")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-i", '--input_folder', help="Must contain all modalities for each patient in the correct"
+ " order (same as training). Files must be named "
+ "CASENAME_XXXX.nii.gz where XXXX is the modality "
+ "identifier (0000, 0001, etc)", required=True)
+ parser.add_argument('-o', "--output_folder", required=True, help="folder for saving predictions")
+ parser.add_argument('-m', '--model_output_folder',
+ help='model output folder. Will automatically discover the folds '
+ 'that were '
+ 'run and use those as an ensemble', required=True)
+ parser.add_argument('-f', '--folds', nargs='+', default='None', help="folds to use for prediction. Default is None "
+ "which means that folds will be detected "
+ "automatically in the model output folder")
+ parser.add_argument('-z', '--save_npz', required=False, action='store_true', help="use this if you want to ensemble"
+ " these predictions with those of"
+ " other models. Softmax "
+ "probabilities will be saved as "
+ "compresed numpy arrays in "
+ "output_folder and can be merged "
+ "between output_folders with "
+ "merge_predictions.py")
+ parser.add_argument('-l', '--lowres_segmentations', required=False, default='None', help="if model is the highres "
+ "stage of the cascade then you need to use -l to specify where the segmentations of the "
+ "corresponding lowres unet are. Here they are required to do a prediction")
+ parser.add_argument("--part_id", type=int, required=False, default=0, help="Used to parallelize the prediction of "
+ "the folder over several GPUs. If you "
+ "want to use n GPUs to predict this "
+ "folder you need to run this command "
+ "n times with --part_id=0, ... n-1 and "
+ "--num_parts=n (each with a different "
+ "GPU (for example via "
+ "CUDA_VISIBLE_DEVICES=X)")
+ parser.add_argument("--num_parts", type=int, required=False, default=1,
+ help="Used to parallelize the prediction of "
+ "the folder over several GPUs. If you "
+ "want to use n GPUs to predict this "
+ "folder you need to run this command "
+ "n times with --part_id=0, ... n-1 and "
+ "--num_parts=n (each with a different "
+ "GPU (via "
+ "CUDA_VISIBLE_DEVICES=X)")
+ parser.add_argument("--num_threads_preprocessing", required=False, default=6, type=int, help=
+ "Determines many background processes will be used for data preprocessing. Reduce this if you "
+ "run into out of memory (RAM) problems. Default: 6")
+ parser.add_argument("--num_threads_nifti_save", required=False, default=2, type=int, help=
+ "Determines many background processes will be used for segmentation export. Reduce this if you "
+ "run into out of memory (RAM) problems. Default: 2")
+ parser.add_argument("--tta", required=False, type=int, default=1, help="Set to 0 to disable test time data "
+ "augmentation (speedup of factor "
+ "4(2D)/8(3D)), "
+ "lower quality segmentations")
+ parser.add_argument("--overwrite_existing", required=False, type=int, default=1, help="Set this to 0 if you need "
+ "to resume a previous "
+ "prediction. Default: 1 "
+ "(=existing segmentations "
+ "in output_folder will be "
+ "overwritten)")
+ parser.add_argument("--mode", type=str, default="normal", required=False)
+ parser.add_argument("--all_in_gpu", type=str, default="None", required=False, help="can be None, False or True")
+ parser.add_argument("--step_size", type=float, default=0.5, required=False, help="don't touch")
+ # parser.add_argument("--interp_order", required=False, default=3, type=int,
+ # help="order of interpolation for segmentations, has no effect if mode=fastest")
+ # parser.add_argument("--interp_order_z", required=False, default=0, type=int,
+ # help="order of interpolation along z is z is done differently")
+ # parser.add_argument("--force_separate_z", required=False, default="None", type=str,
+ # help="force_separate_z resampling. Can be None, True or False, has no effect if mode=fastest")
+ parser.add_argument('--disable_mixed_precision', default=False, action='store_true', required=False,
+ help='Predictions are done with mixed precision by default. This improves speed and reduces '
+ 'the required vram. If you want to disable mixed precision you can set this flag. Note '
+ 'that this is not recommended (mixed precision is ~2x faster!)')
+
+ args = parser.parse_args()
+ input_folder = args.input_folder
+ output_folder = args.output_folder
+ part_id = args.part_id
+ num_parts = args.num_parts
+ model = args.model_output_folder
+ folds = args.folds
+ save_npz = args.save_npz
+ lowres_segmentations = args.lowres_segmentations
+ num_threads_preprocessing = args.num_threads_preprocessing
+ num_threads_nifti_save = args.num_threads_nifti_save
+ tta = args.tta
+ step_size = args.step_size
+
+ # interp_order = args.interp_order
+ # interp_order_z = args.interp_order_z
+ # force_separate_z = args.force_separate_z
+
+ # if force_separate_z == "None":
+ # force_separate_z = None
+ # elif force_separate_z == "False":
+ # force_separate_z = False
+ # elif force_separate_z == "True":
+ # force_separate_z = True
+ # else:
+ # raise ValueError("force_separate_z must be None, True or False. Given: %s" % force_separate_z)
+
+ overwrite = args.overwrite_existing
+ mode = args.mode
+ all_in_gpu = args.all_in_gpu
+
+ if lowres_segmentations == "None":
+ lowres_segmentations = None
+
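+ # argparse returns a list for nargs='+'; the default 'None' stays a plain string, which is
+ # why both cases are handled below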
+ if isinstance(folds, list):
+ if folds[0] == 'all' and len(folds) == 1:
+ pass
+ else:
+ folds = [int(i) for i in folds]
+ elif folds == "None":
+ folds = None
+ else:
+ raise ValueError("Unexpected value for argument folds")
+
+ if tta == 0:
+ tta = False
+ elif tta == 1:
+ tta = True
+ else:
+ raise ValueError("Unexpected value for tta, Use 1 or 0")
+
+ if overwrite == 0:
+ overwrite = False
+ elif overwrite == 1:
+ overwrite = True
+ else:
+ raise ValueError("Unexpected value for overwrite, Use 1 or 0")
+
+ assert all_in_gpu in ['None', 'False', 'True']
+ if all_in_gpu == "None":
+ all_in_gpu = None
+ elif all_in_gpu == "True":
+ all_in_gpu = True
+ elif all_in_gpu == "False":
+ all_in_gpu = False
+
+ predict_from_folder(model, input_folder, output_folder, folds, save_npz, num_threads_preprocessing,
+ num_threads_nifti_save, lowres_segmentations, part_id, num_parts, tta,
+ mixed_precision=not args.disable_mixed_precision,
+ overwrite_existing=overwrite, mode=mode, overwrite_all_in_gpu=all_in_gpu, step_size=step_size)
diff --git a/nnunet/inference/predict_simple.py b/nnunet/inference/predict_simple.py
new file mode 100644
index 0000000000000000000000000000000000000000..31a4b8aa3b8c988af020a241f9461661098d0116
--- /dev/null
+++ b/nnunet/inference/predict_simple.py
@@ -0,0 +1,225 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+import torch
+
+from nnunet.inference.predict import predict_from_folder
+from nnunet.paths import default_plans_identifier, network_training_output_dir, default_cascade_trainer, default_trainer
+from batchgenerators.utilities.file_and_folder_operations import join, isdir
+from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-i", '--input_folder', help="Must contain all modalities for each patient in the correct"
+ " order (same as training). Files must be named "
+ "CASENAME_XXXX.nii.gz where XXXX is the modality "
+ "identifier (0000, 0001, etc)", required=True)
+ parser.add_argument('-o', "--output_folder", required=True, help="folder for saving predictions")
+ parser.add_argument('-t', '--task_name', help='task name or task ID, required.',
+ required=True)
+
+ parser.add_argument('-tr', '--trainer_class_name',
+ help='Name of the nnUNetTrainer used for 2D U-Net, full resolution 3D U-Net and low resolution '
+ 'U-Net. The default is %s. If you are running inference with the cascade and the folder '
+ 'pointed to by --lowres_segmentations does not contain the segmentation maps generated by '
+ 'the low resolution U-Net then the low resolution segmentation maps will be automatically '
+ 'generated. For this case, make sure to set the trainer class here that matches your '
+ '--cascade_trainer_class_name (this part can be ignored if defaults are used).'
+ % default_trainer,
+ required=False,
+ default=default_trainer)
+ parser.add_argument('-ctr', '--cascade_trainer_class_name',
+ help="Trainer class name used for predicting the 3D full resolution U-Net part of the cascade."
+ "Default is %s" % default_cascade_trainer, required=False,
+ default=default_cascade_trainer)
+
+ parser.add_argument('-m', '--model', help="2d, 3d_lowres, 3d_fullres or 3d_cascade_fullres. Default: 3d_fullres",
+ default="3d_fullres", required=False)
+
+ parser.add_argument('-p', '--plans_identifier', help='do not touch this unless you know what you are doing',
+ default=default_plans_identifier, required=False)
+
+ parser.add_argument('-f', '--folds', nargs='+', default='None',
+ help="folds to use for prediction. Default is None which means that folds will be detected "
+ "automatically in the model output folder")
+
+ parser.add_argument('-z', '--save_npz', required=False, action='store_true',
+ help="use this if you want to ensemble these predictions with those of other models. Softmax "
+ "probabilities will be saved as compressed numpy arrays in output_folder and can be "
+ "merged between output_folders with nnUNet_ensemble_predictions")
+
+ parser.add_argument('-l', '--lowres_segmentations', required=False, default='None',
+ help="if model is the highres stage of the cascade then you can use this folder to provide "
+ "predictions from the low resolution 3D U-Net. If this is left at default, the "
+ "predictions will be generated automatically (provided that the 3D low resolution U-Net "
+ "network weights are present")
+
+ parser.add_argument("--part_id", type=int, required=False, default=0, help="Used to parallelize the prediction of "
+ "the folder over several GPUs. If you "
+ "want to use n GPUs to predict this "
+ "folder you need to run this command "
+ "n times with --part_id=0, ... n-1 and "
+ "--num_parts=n (each with a different "
+ "GPU (for example via "
+ "CUDA_VISIBLE_DEVICES=X)")
+
+ parser.add_argument("--num_parts", type=int, required=False, default=1,
+ help="Used to parallelize the prediction of "
+ "the folder over several GPUs. If you "
+ "want to use n GPUs to predict this "
+ "folder you need to run this command "
+ "n times with --part_id=0, ... n-1 and "
+ "--num_parts=n (each with a different "
+ "GPU (via "
+ "CUDA_VISIBLE_DEVICES=X)")
+
+ parser.add_argument("--num_threads_preprocessing", required=False, default=6, type=int, help=
+ "Determines many background processes will be used for data preprocessing. Reduce this if you "
+ "run into out of memory (RAM) problems. Default: 6")
+
+ parser.add_argument("--num_threads_nifti_save", required=False, default=2, type=int, help=
+ "Determines many background processes will be used for segmentation export. Reduce this if you "
+ "run into out of memory (RAM) problems. Default: 2")
+
+ parser.add_argument("--disable_tta", required=False, default=False, action="store_true",
+ help="set this flag to disable test time data augmentation via mirroring. Speeds up inference "
+ "by roughly factor 4 (2D) or 8 (3D)")
+
+ parser.add_argument("--overwrite_existing", required=False, default=False, action="store_true",
+ help="Set this flag if the target folder contains predictions that you would like to overwrite")
+
+ parser.add_argument("--mode", type=str, default="normal", required=False, help="Hands off!")
+ parser.add_argument("--all_in_gpu", type=str, default="None", required=False, help="can be None, False or True. "
+ "Do not touch.")
+ parser.add_argument("--step_size", type=float, default=0.5, required=False, help="don't touch")
+ # parser.add_argument("--interp_order", required=False, default=3, type=int,
+ # help="order of interpolation for segmentations, has no effect if mode=fastest. Do not touch this.")
+ # parser.add_argument("--interp_order_z", required=False, default=0, type=int,
+ # help="order of interpolation along z is z is done differently. Do not touch this.")
+ # parser.add_argument("--force_separate_z", required=False, default="None", type=str,
+ # help="force_separate_z resampling. Can be None, True or False, has no effect if mode=fastest. "
+ # "Do not touch this.")
+ parser.add_argument('-chk',
+ help='checkpoint name, default: model_final_checkpoint',
+ required=False,
+ default='model_final_checkpoint')
+ parser.add_argument('--disable_mixed_precision', default=False, action='store_true', required=False,
+ help='Predictions are done with mixed precision by default. This improves speed and reduces '
+ 'the required vram. If you want to disable mixed precision you can set this flag. Note '
+ 'that this is not recommended (mixed precision is ~2x faster!)')
+
+ args = parser.parse_args()
+ input_folder = args.input_folder
+ output_folder = args.output_folder
+ part_id = args.part_id
+ num_parts = args.num_parts
+ folds = args.folds
+ save_npz = args.save_npz
+ lowres_segmentations = args.lowres_segmentations
+ num_threads_preprocessing = args.num_threads_preprocessing
+ num_threads_nifti_save = args.num_threads_nifti_save
+ disable_tta = args.disable_tta
+ step_size = args.step_size
+ # interp_order = args.interp_order
+ # interp_order_z = args.interp_order_z
+ # force_separate_z = args.force_separate_z
+ overwrite_existing = args.overwrite_existing
+ mode = args.mode
+ all_in_gpu = args.all_in_gpu
+ model = args.model
+ trainer_class_name = args.trainer_class_name
+ cascade_trainer_class_name = args.cascade_trainer_class_name
+
+ task_name = args.task_name
+
+ if not task_name.startswith("Task"):
+ task_id = int(task_name)
+ task_name = convert_id_to_task_name(task_id)
+
+ assert model in ["2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"], "-m must be 2d, 3d_lowres, 3d_fullres or " \
+ "3d_cascade_fullres"
+
+ # if force_separate_z == "None":
+ # force_separate_z = None
+ # elif force_separate_z == "False":
+ # force_separate_z = False
+ # elif force_separate_z == "True":
+ # force_separate_z = True
+ # else:
+ # raise ValueError("force_separate_z must be None, True or False. Given: %s" % force_separate_z)
+
+ if lowres_segmentations == "None":
+ lowres_segmentations = None
+
+ if isinstance(folds, list):
+ if folds[0] == 'all' and len(folds) == 1:
+ pass
+ else:
+ folds = [int(i) for i in folds]
+ elif folds == "None":
+ folds = None
+ else:
+ raise ValueError("Unexpected value for argument folds")
+
+ assert all_in_gpu in ['None', 'False', 'True']
+ if all_in_gpu == "None":
+ all_in_gpu = None
+ elif all_in_gpu == "True":
+ all_in_gpu = True
+ elif all_in_gpu == "False":
+ all_in_gpu = False
+
+ # we need to catch the case where model is 3d cascade fullres and the low resolution folder has not been set.
+ # In that case we need to try and predict with 3d low res first
+ if model == "3d_cascade_fullres" and lowres_segmentations is None:
+ print("lowres_segmentations is None. Attempting to predict 3d_lowres first...")
+ assert part_id == 0 and num_parts == 1, "if you don't specify a --lowres_segmentations folder for the " \
+ "inference of the cascade, custom values for part_id and num_parts " \
+ "are not supported. If you wish to have multiple parts, please " \
+ "run the 3d_lowres inference first (separately)"
+ model_folder_name = join(network_training_output_dir, "3d_lowres", task_name, trainer_class_name + "__" +
+ args.plans_identifier)
+ assert isdir(model_folder_name), "model output folder not found. Expected: %s" % model_folder_name
+ lowres_output_folder = join(output_folder, "3d_lowres_predictions")
+ predict_from_folder(model_folder_name, input_folder, lowres_output_folder, folds, False,
+ num_threads_preprocessing, num_threads_nifti_save, None, part_id, num_parts, not disable_tta,
+ overwrite_existing=overwrite_existing, mode=mode, overwrite_all_in_gpu=all_in_gpu,
+ mixed_precision=not args.disable_mixed_precision,
+ step_size=step_size)
+ lowres_segmentations = lowres_output_folder
+ torch.cuda.empty_cache()
+ print("3d_lowres done")
+
+ if model == "3d_cascade_fullres":
+ trainer = cascade_trainer_class_name
+ else:
+ trainer = trainer_class_name
+
+ model_folder_name = join(network_training_output_dir, model, task_name, trainer + "__" +
+ args.plans_identifier)
+ print("using model stored in ", model_folder_name)
+ assert isdir(model_folder_name), "model output folder not found. Expected: %s" % model_folder_name
+
+ predict_from_folder(model_folder_name, input_folder, output_folder, folds, save_npz, num_threads_preprocessing,
+ num_threads_nifti_save, lowres_segmentations, part_id, num_parts, not disable_tta,
+ overwrite_existing=overwrite_existing, mode=mode, overwrite_all_in_gpu=all_in_gpu,
+ mixed_precision=not args.disable_mixed_precision,
+ step_size=step_size, checkpoint_name=args.chk)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/nnunet/inference/pretrained_models/__init__.py b/nnunet/inference/pretrained_models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/inference/pretrained_models/collect_pretrained_models.py b/nnunet/inference/pretrained_models/collect_pretrained_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8e382cc323438b6307b248f2647a323bcd9700f
--- /dev/null
+++ b/nnunet/inference/pretrained_models/collect_pretrained_models.py
@@ -0,0 +1,271 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import zipfile
+from multiprocessing.pool import Pool
+
+from batchgenerators.utilities.file_and_folder_operations import *
+import shutil
+from nnunet.paths import default_cascade_trainer, default_plans_identifier, default_trainer, network_training_output_dir
+from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name
+from subprocess import call
+
+
+def copy_fold(in_folder: str, out_folder: str):
+ shutil.copy(join(in_folder, "debug.json"), join(out_folder, "debug.json"))
+ shutil.copy(join(in_folder, "model_final_checkpoint.model"), join(out_folder, "model_final_checkpoint.model"))
+ shutil.copy(join(in_folder, "model_final_checkpoint.model.pkl"),
+ join(out_folder, "model_final_checkpoint.model.pkl"))
+ shutil.copy(join(in_folder, "progress.png"), join(out_folder, "progress.png"))
+ if isfile(join(in_folder, "network_architecture.pdf")):
+ shutil.copy(join(in_folder, "network_architecture.pdf"), join(out_folder, "network_architecture.pdf"))
+
+
+def copy_model(directory: str, output_directory: str):
+ """
+
+ :param directory: must have the 5 fold_X subfolders as well as a postprocessing.json and plans.pkl
+ :param output_directory:
+ :return:
+ """
+ expected_folders = ["fold_%d" % i for i in range(5)]
+ assert all([isdir(join(directory, i)) for i in expected_folders]), "not all folds present"
+
+ assert isfile(join(directory, "plans.pkl")), "plans.pkl missing"
+ assert isfile(join(directory, "postprocessing.json")), "postprocessing.json missing"
+
+ for e in expected_folders:
+ maybe_mkdir_p(join(output_directory, e))
+ copy_fold(join(directory, e), join(output_directory, e))
+
+ shutil.copy(join(directory, "plans.pkl"), join(output_directory, "plans.pkl"))
+ shutil.copy(join(directory, "postprocessing.json"), join(output_directory, "postprocessing.json"))
+
+
+def copy_pretrained_models_for_task(task_name: str, output_directory: str,
+ models: tuple = ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
+ nnunet_trainer=default_trainer,
+ nnunet_trainer_cascade=default_cascade_trainer,
+ plans_identifier=default_plans_identifier):
+ trainer_output_dir = nnunet_trainer + "__" + plans_identifier
+ trainer_output_dir_cascade = nnunet_trainer_cascade + "__" + plans_identifier
+
+ for m in models:
+ to = trainer_output_dir_cascade if m == "3d_cascade_fullres" else trainer_output_dir
+ expected_output_folder = join(network_training_output_dir, m, task_name, to)
+ if not isdir(expected_output_folder):
+ if m == "3d_lowres" or m == "3d_cascade_fullres":
+ print("Task", task_name, "does not seem to have the cascade")
+ continue
+ else:
+ raise RuntimeError("missing folder! %s" % expected_output_folder)
+ output_here = join(output_directory, m, task_name, to)
+ maybe_mkdir_p(output_here)
+ copy_model(expected_output_folder, output_here)
+
+
+def check_if_valid(ensemble: str, valid_models, valid_trainers, valid_plans):
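+ # ensemble folder names look like "ensemble_<model1>__<trainer1>__<plans1>--<model2>__<trainer2>__<plans2>";
+ # strip the prefix and check every component against the allowed values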
+ ensemble = ensemble[len("ensemble_"):]
+ mb1, mb2 = ensemble.split("--")
+ c1, tr1, p1 = mb1.split("__")
+ c2, tr2, p2 = mb2.split("__")
+ if c1 not in valid_models: return False
+ if c2 not in valid_models: return False
+ if tr1 not in valid_trainers: return False
+ if tr2 not in valid_trainers: return False
+ if p1 not in valid_plans: return False
+ if p2 not in valid_plans: return False
+ return True
+
+
+def copy_ensembles(taskname, output_folder, valid_models=('2d', '3d_fullres', '3d_lowres', '3d_cascade_fullres'),
+ valid_trainers=(default_trainer, default_cascade_trainer),
+ valid_plans=(default_plans_identifier,)):
+ ensemble_dir = join(network_training_output_dir, 'ensembles', taskname)
+ if not isdir(ensemble_dir):
+ print("No ensemble directory found for task", taskname)
+ return
+ subd = subdirs(ensemble_dir, join=False)
+ valid = []
+ for s in subd:
+ v = check_if_valid(s, valid_models, valid_trainers, valid_plans)
+ if v:
+ valid.append(s)
+ output_ensemble = join(output_folder, 'ensembles', taskname)
+ maybe_mkdir_p(output_ensemble)
+ for v in valid:
+ this_output = join(output_ensemble, v)
+ maybe_mkdir_p(this_output)
+ shutil.copy(join(ensemble_dir, v, 'postprocessing.json'), this_output)
+
+
+def compress_everything(output_base, num_processes=8):
+ p = Pool(num_processes)
+ # subfolders(..., join=False) already returns plain folder names, which double as task names
+ tasks = subfolders(output_base, join=False)
+ args = []
+ for t in tasks:
+ args.append((join(output_base, t + ".zip"), join(output_base, t)))
+ p.starmap(compress_folder, args)
+ p.close()
+ p.join()
+
+
+def compress_folder(zip_file, folder):
+ """inspired by https://stackoverflow.com/questions/1855095/how-to-create-a-zip-archive-of-a-directory-in-python"""
+ zipf = zipfile.ZipFile(zip_file, 'w', zipfile.ZIP_DEFLATED)
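+ # store each file under its path relative to `folder` so the archive contains no absolute paths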
+ for root, dirs, files in os.walk(folder):
+ for file in files:
+ zipf.write(join(root, file), os.path.relpath(join(root, file), folder))
+
+
+def export_one_task(taskname, models, output_folder, nnunet_trainer=default_trainer,
+ nnunet_trainer_cascade=default_cascade_trainer,
+ plans_identifier=default_plans_identifier):
+ copy_pretrained_models_for_task(taskname, output_folder, models, nnunet_trainer, nnunet_trainer_cascade,
+ plans_identifier)
+ copy_ensembles(taskname, output_folder, models, (nnunet_trainer, nnunet_trainer_cascade), (plans_identifier,))
+ compress_folder(join(output_folder, taskname + '.zip'), join(output_folder, taskname))
+
+
+def export_pretrained_model(task_name: str, output_file: str,
+ models: tuple = ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"),
+ nnunet_trainer=default_trainer,
+ nnunet_trainer_cascade=default_cascade_trainer,
+ plans_identifier=default_plans_identifier,
+ folds=(0, 1, 2, 3, 4), strict=True):
+ zipf = zipfile.ZipFile(output_file, 'w', zipfile.ZIP_DEFLATED)
+ trainer_output_dir = nnunet_trainer + "__" + plans_identifier
+ trainer_output_dir_cascade = nnunet_trainer_cascade + "__" + plans_identifier
+
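+ # paths inside the zip are made relative to network_training_output_dir, so the archive can
+ # later be extracted directly into RESULTS_FOLDER by install_model_from_zip_file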
+ for m in models:
+ to = trainer_output_dir_cascade if m == "3d_cascade_fullres" else trainer_output_dir
+ expected_output_folder = join(network_training_output_dir, m, task_name, to)
+ if not isdir(expected_output_folder):
+ if strict:
+ raise RuntimeError("Task %s is missing the model %s" % (task_name, m))
+ else:
+ continue
+
+ expected_folders = ["fold_%d" % i if i != 'all' else i for i in folds]
+ assert all([isdir(join(expected_output_folder, i)) for i in expected_folders]), "not all requested folds " \
+ "present, " \
+ "Task %s model %s" % \
+ (task_name, m)
+
+ assert isfile(join(expected_output_folder, "plans.pkl")), "plans.pkl missing, Task %s model %s" % (task_name, m)
+
+ for e in expected_folders:
+ zipf.write(join(expected_output_folder, e, "debug.json"),
+ os.path.relpath(join(expected_output_folder, e, "debug.json"),
+ network_training_output_dir))
+ zipf.write(join(expected_output_folder, e, "model_final_checkpoint.model"),
+ os.path.relpath(join(expected_output_folder, e, "model_final_checkpoint.model"),
+ network_training_output_dir))
+ zipf.write(join(expected_output_folder, e, "model_final_checkpoint.model.pkl"),
+ os.path.relpath(join(expected_output_folder, e, "model_final_checkpoint.model.pkl"),
+ network_training_output_dir))
+ zipf.write(join(expected_output_folder, e, "progress.png"),
+ os.path.relpath(join(expected_output_folder, e, "progress.png"), network_training_output_dir))
+ if isfile(join(expected_output_folder, e, "network_architecture.pdf")):
+ zipf.write(join(expected_output_folder, e, "network_architecture.pdf"),
+ os.path.relpath(join(expected_output_folder, e, "network_architecture.pdf"),
+ network_training_output_dir))
+
+ zipf.write(join(expected_output_folder, "plans.pkl"),
+ os.path.relpath(join(expected_output_folder, "plans.pkl"), network_training_output_dir))
+ if not isfile(join(expected_output_folder, "postprocessing.json")):
+ if strict:
+ raise RuntimeError('postprocessing.json missing. Run nnUNet_determine_postprocessing or disable strict')
+ else:
+ print('WARNING: postprocessing.json missing')
+ else:
+ zipf.write(join(expected_output_folder, "postprocessing.json"),
+ os.path.relpath(join(expected_output_folder, "postprocessing.json"), network_training_output_dir))
+
+ ensemble_dir = join(network_training_output_dir, 'ensembles', task_name)
+ if not isdir(ensemble_dir):
+ print("No ensemble directory found for task", task_name)
+ return
+ subd = subdirs(ensemble_dir, join=False)
+ valid = []
+ for s in subd:
+ v = check_if_valid(s, models, (nnunet_trainer, nnunet_trainer_cascade), (plans_identifier,))
+ if v:
+ valid.append(s)
+ for v in valid:
+ zipf.write(join(ensemble_dir, v, 'postprocessing.json'),
+ os.path.relpath(join(ensemble_dir, v, 'postprocessing.json'),
+ network_training_output_dir))
+ zipf.close()
+
+
+def export_entry_point():
+ import argparse
+ parser = argparse.ArgumentParser(description="Use this script to export models to a zip file for sharing with "
+ "others. You can upload the zip file and then either share the url "
+ "for usage with nnUNet_download_pretrained_model_by_url, or share the "
+ "zip for usage with nnUNet_install_pretrained_model_from_zip")
+ parser.add_argument('-t', type=str, help='task name or task id')
+ parser.add_argument('-o', type=str, help='output file name. Should end with .zip')
+ parser.add_argument('-m', nargs='+',
+ help='list of model configurations. Default: 2d 3d_lowres 3d_fullres 3d_cascade_fullres. Must '
+ 'be adapted to fit the available models of a task',
+ default=("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres"), required=False)
+ parser.add_argument('-tr', type=str, help='trainer class used for 2d 3d_lowres and 3d_fullres. '
+ 'Default: %s' % default_trainer, required=False, default=default_trainer)
+ parser.add_argument('-trc', type=str, help='trainer class used for 3d_cascade_fullres. '
+ 'Default: %s' % default_cascade_trainer, required=False,
+ default=default_cascade_trainer)
+ parser.add_argument('-pl', type=str, help='nnunet plans identifier. Default: %s' % default_plans_identifier,
+ required=False, default=default_plans_identifier)
+ parser.add_argument('--disable_strict', action='store_true', help='set this if you want to allow skipping '
+ 'missing things', required=False)
+ parser.add_argument('-f', nargs='+', help='Folds. Default: 0 1 2 3 4', required=False, default=[0, 1, 2, 3, 4])
+ args = parser.parse_args()
+
+ folds = args.f
+ folds = [int(i) if i != 'all' else i for i in folds]
+
+ taskname = args.t
+ if taskname.startswith("Task"):
+ pass
+ else:
+ try:
+ taskid = int(taskname)
+ except Exception as e:
+ print('-t must be either a Task name (TaskXXX_YYY) or a task id (integer)')
+ raise e
+ taskname = convert_id_to_task_name(taskid)
+
+ export_pretrained_model(taskname, args.o, args.m, args.tr, args.trc, args.pl, strict=not args.disable_strict,
+ folds=folds)
+
+
+def export_for_paper():
+ output_base = "/media/fabian/DeepLearningData/nnunet_trained_models"
+ task_ids = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 17, 24, 27, 29, 35, 48, 55, 61, 38]
+ for t in task_ids:
+ if t == 61:
+ models = ("3d_fullres",)
+ else:
+ models = ("2d", "3d_lowres", "3d_fullres", "3d_cascade_fullres")
+ taskname = convert_id_to_task_name(t)
+ print(taskname)
+ output_folder = join(output_base, taskname)
+ maybe_mkdir_p(output_folder)
+ copy_pretrained_models_for_task(taskname, output_folder, models)
+ copy_ensembles(taskname, output_folder)
+ compress_everything(output_base, 8)
diff --git a/nnunet/inference/pretrained_models/download_pretrained_model.py b/nnunet/inference/pretrained_models/download_pretrained_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebed6d6dc3b4d64db1f1320bf328cbec4808ae4a
--- /dev/null
+++ b/nnunet/inference/pretrained_models/download_pretrained_model.py
@@ -0,0 +1,367 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Optional
+
+import zipfile
+from time import time
+
+import requests
+from batchgenerators.utilities.file_and_folder_operations import join, isfile
+
+from nnunet.paths import network_training_output_dir
+
+
+def get_available_models():
+ available_models = {
+ "Task001_BrainTumour": {
+ 'description': "Brain Tumor Segmentation. \n"
+ "Segmentation targets are edema, enhancing tumor and necrosis, \n"
+ "Input modalities are 0: FLAIR, 1: T1, 2: T1 with contrast agent, 3: T2. \n"
+ "Also see Medical Segmentation Decathlon, http://medicaldecathlon.com/",
+ 'url': "https://zenodo.org/record/4003545/files/Task001_BrainTumour.zip?download=1"
+ },
+ "Task002_Heart": {
+ 'description': "Left Atrium Segmentation. \n"
+ "Segmentation target is the left atrium, \n"
+ "Input modalities are 0: MRI. \n"
+ "Also see Medical Segmentation Decathlon, http://medicaldecathlon.com/",
+ 'url': "https://zenodo.org/record/4003545/files/Task002_Heart.zip?download=1"
+ },
+ "Task003_Liver": {
+ 'description': "Liver and Liver Tumor Segmentation. \n"
+ "Segmentation targets are liver and tumors, \n"
+ "Input modalities are 0: abdominal CT scan. \n"
+ "Also see Medical Segmentation Decathlon, http://medicaldecathlon.com/",
+ 'url': "https://zenodo.org/record/4003545/files/Task003_Liver.zip?download=1"
+ },
+ "Task004_Hippocampus": {
+ 'description': "Hippocampus Segmentation. \n"
+ "Segmentation targets posterior and anterior parts of the hippocampus, \n"
+ "Input modalities are 0: MRI. \n"
+ "Also see Medical Segmentation Decathlon, http://medicaldecathlon.com/",
+ 'url': "https://zenodo.org/record/4003545/files/Task004_Hippocampus.zip?download=1"
+ },
+ "Task005_Prostate": {
+ 'description': "Prostate Segmentation. \n"
+ "Segmentation targets are peripheral and central zone, \n"
+ "Input modalities are 0: T2, 1: ADC. \n"
+ "Also see Medical Segmentation Decathlon, http://medicaldecathlon.com/",
+ 'url': "https://zenodo.org/record/4485926/files/Task005_Prostate.zip?download=1"
+ },
+ "Task006_Lung": {
+ 'description': "Lung Nodule Segmentation. \n"
+ "Segmentation target are lung nodules, \n"
+ "Input modalities are 0: abdominal CT scan. \n"
+ "Also see Medical Segmentation Decathlon, http://medicaldecathlon.com/",
+ 'url': "https://zenodo.org/record/4003545/files/Task006_Lung.zip?download=1"
+ },
+ "Task007_Pancreas": {
+ 'description': "Pancreas Segmentation. \n"
+ "Segmentation targets are pancras and pancreas tumor, \n"
+ "Input modalities are 0: abdominal CT scan. \n"
+ "Also see Medical Segmentation Decathlon, http://medicaldecathlon.com/",
+ 'url': "https://zenodo.org/record/4003545/files/Task007_Pancreas.zip?download=1"
+ },
+ "Task008_HepaticVessel": {
+ 'description': "Hepatic Vessel Segmentation. \n"
+ "Segmentation targets are hepatic vesels and liver tumors, \n"
+ "Input modalities are 0: abdominal CT scan. \n"
+ "Also see Medical Segmentation Decathlon, http://medicaldecathlon.com/",
+ 'url': "https://zenodo.org/record/4003545/files/Task008_HepaticVessel.zip?download=1"
+ },
+ "Task009_Spleen": {
+ 'description': "Spleen Segmentation. \n"
+ "Segmentation target is the spleen, \n"
+ "Input modalities are 0: abdominal CT scan. \n"
+ "Also see Medical Segmentation Decathlon, http://medicaldecathlon.com/",
+ 'url': "https://zenodo.org/record/4003545/files/Task009_Spleen.zip?download=1"
+ },
+ "Task010_Colon": {
+ 'description': "Colon Cancer Segmentation. \n"
+ "Segmentation target are colon caner primaries, \n"
+ "Input modalities are 0: CT scan. \n"
+ "Also see Medical Segmentation Decathlon, http://medicaldecathlon.com/",
+ 'url': "https://zenodo.org/record/4003545/files/Task010_Colon.zip?download=1"
+ },
+ "Task017_AbdominalOrganSegmentation": {
+ 'description': "Multi-Atlas Labeling Beyond the Cranial Vault - Abdomen. \n"
+ "Segmentation targets are thirteen different abdominal organs, \n"
+ "Input modalities are 0: abdominal CT scan. \n"
+ "Also see https://www.synapse.org/#!Synapse:syn3193805/wiki/217754",
+ 'url': "https://zenodo.org/record/4003545/files/Task017_AbdominalOrganSegmentation.zip?download=1"
+ },
+ "Task024_Promise": {
+ 'description': "Prostate MR Image Segmentation 2012. \n"
+ "Segmentation target is the prostate, \n"
+ "Input modalities are 0: T2. \n"
+ "Also see https://promise12.grand-challenge.org/",
+ 'url': "https://zenodo.org/record/4003545/files/Task024_Promise.zip?download=1"
+ },
+ "Task027_ACDC": {
+ 'description': "Automatic Cardiac Diagnosis Challenge. \n"
+ "Segmentation targets are right ventricle, left ventricular cavity and left myocardium, \n"
+ "Input modalities are 0: cine MRI. \n"
+ "Also see https://acdc.creatis.insa-lyon.fr/",
+ 'url': "https://zenodo.org/record/4003545/files/Task027_ACDC.zip?download=1"
+ },
+ "Task029_LiTS": {
+ 'description': "Liver and Liver Tumor Segmentation Challenge. \n"
+ "Segmentation targets are liver and liver tumors, \n"
+ "Input modalities are 0: abdominal CT scan. \n"
+ "Also see https://competitions.codalab.org/competitions/17094",
+ 'url': "https://zenodo.org/record/4003545/files/Task029_LITS.zip?download=1"
+ },
+ "Task035_ISBILesionSegmentation": {
+ 'description': "Longitudinal multiple sclerosis lesion segmentation Challenge. \n"
+ "Segmentation target is MS lesions, \n"
+ "input modalities are 0: FLAIR, 1: MPRAGE, 2: proton density, 3: T2. \n"
+ "Also see https://smart-stats-tools.org/lesion-challenge",
+ 'url': "https://zenodo.org/record/4003545/files/Task035_ISBILesionSegmentation.zip?download=1"
+ },
+ "Task038_CHAOS_Task_3_5_Variant2": {
+ 'description': "CHAOS - Combined (CT-MR) Healthy Abdominal Organ Segmentation Challenge (Task 3 & 5). \n"
+ "Segmentation targets are left and right kidney, liver, spleen, \n"
+ "Input modalities are 0: T1 in-phase, T1 out-phase, T2 (can be any of those)\n"
+ "Also see https://chaos.grand-challenge.org/",
+ 'url': "https://zenodo.org/record/4003545/files/Task038_CHAOS_Task_3_5_Variant2.zip?download=1"
+ },
+ "Task048_KiTS_clean": {
+ 'description': "Kidney and Kidney Tumor Segmentation Challenge. "
+ "Segmentation targets kidney and kidney tumors, "
+ "Input modalities are 0: abdominal CT scan. "
+ "Also see https://kits19.grand-challenge.org/",
+ 'url': "https://zenodo.org/record/4003545/files/Task048_KiTS_clean.zip?download=1"
+ },
+ "Task055_SegTHOR": {
+ 'description': "SegTHOR: Segmentation of THoracic Organs at Risk in CT images. \n"
+ "Segmentation targets are aorta, esophagus, heart and trachea, \n"
+ "Input modalities are 0: CT scan. \n"
+ "Also see https://competitions.codalab.org/competitions/21145",
+ 'url': "https://zenodo.org/record/4003545/files/Task055_SegTHOR.zip?download=1"
+ },
+ "Task061_CREMI": {
+ 'description': "MICCAI Challenge on Circuit Reconstruction from Electron Microscopy Images (Synaptic Cleft segmentation task). \n"
+ "Segmentation target is synaptic clefts, \n"
+ "Input modalities are 0: serial section transmission electron microscopy of neural tissue. \n"
+ "Also see https://cremi.org/",
+ 'url': "https://zenodo.org/record/4003545/files/Task061_CREMI.zip?download=1"
+ },
+ "Task075_Fluo_C3DH_A549_ManAndSim": {
+ 'description': "Fluo-C3DH-A549-SIM and Fluo-C3DH-A549 datasets of the cell tracking challenge. Segmentation target are C3DH cells in fluorescence microscopy images.\n"
+ "Input modalities are 0: fluorescence_microscopy\n"
+ "Also see http://celltrackingchallenge.net/",
+ 'url': "https://zenodo.org/record/4003545/files/Task075_Fluo_C3DH_A549_ManAndSim.zip?download=1"
+ },
+ "Task076_Fluo_N3DH_SIM": {
+ 'description': "Fluo-N3DH-SIM dataset of the cell tracking challenge. Segmentation target are N3DH cells and cell borders in fluorescence microscopy images.\n"
+ "Input modalities are 0: fluorescence_microscopy\n"
+ "Also see http://celltrackingchallenge.net/\n"
+ "Note that the segmentation output of the models are cell center and cell border. These outputs mus tbe converted to an instance segmentation for the challenge. \n"
+ "See https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/dataset_conversion/Task076_Fluo_N3DH_SIM.py",
+ 'url': "https://zenodo.org/record/4003545/files/Task076_Fluo_N3DH_SIM.zip?download=1"
+ },
+ "Task082_BraTS2020": {
+ 'description': "Brain tumor segmentation challenge 2020 (BraTS)\n"
+ "Segmentation targets are 0: background, 1: edema, 2: necrosis, 3: enhancing tumor\n"
+ "Input modalities are 0: T1, 1: T1ce, 2: T2, 3: FLAIR (MRI images)\n"
+ "Also see https://www.med.upenn.edu/cbica/brats2020/",
+ 'url': (
+ "https://zenodo.org/record/4635763/files/Task082_nnUNetTrainerV2__nnUNetPlansv2.1_5fold.zip?download=1",
+ "https://zenodo.org/record/4635763/files/Task082_nnUNetTrainerV2BraTSRegions_DA3_BN_BD__nnUNetPlansv2.1_bs5_5fold.zip?download=1",
+ "https://zenodo.org/record/4635763/files/Task082_nnUNetTrainerV2BraTSRegions_DA4_BN__nnUNetPlansv2.1_bs5_15fold.zip?download=1",
+ "https://zenodo.org/record/4635763/files/Task082_nnUNetTrainerV2BraTSRegions_DA4_BN_BD__nnUNetPlansv2.1_bs5_5fold.zip?download=1",
+ )
+ },
+ "Task089_Fluo-N2DH-SIM_thickborder_time": {
+ 'description': "Fluo-N2DH-SIM dataset of the cell tracking challenge. Segmentation target are nuclei of N2DH cells and cell borders in fluorescence microscopy images.\n"
+ "Input modalities are 0: t minus 4, 0: t minus 3, 0: t minus 2, 0: t minus 1, 0: frame of interest\n"
+ "Note that the input channels are different time steps from a time series acquisition\n"
+ "Note that the segmentation output of the models are cell center and cell border. These outputs mus tbe converted to an instance segmentation for the challenge. \n"
+ "See https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/dataset_conversion/Task089_Fluo-N2DH-SIM.py\n"
+ "Also see http://celltrackingchallenge.net/",
+ 'url': "https://zenodo.org/record/4003545/files/Task089_Fluo-N2DH-SIM_thickborder_time.zip?download=1"
+ },
+ "Task114_heart_MNMs": {
+ 'description': "Cardiac MRI short axis images from the M&Ms challenge 2020.\n"
+ "Input modalities are 0: MRI \n"
+ "See also https://www.ub.edu/mnms/ \n"
+ "Note: Labels of the M&Ms Challenge are not in the same order as for the ACDC challenge. \n"
+ "See https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/dataset_conversion/Task114_heart_mnms.py",
+ 'url': "https://zenodo.org/record/4288464/files/Task114_heart_MNMs.zip?download=1"
+ },
+ "Task115_COVIDSegChallenge": {
+ 'description': "Covid lesion segmentation in CT images. Data originates from COVID-19-20 challenge.\n"
+ "Predicted labels are 0: background, 1: covid lesion\n"
+ "Input modalities are 0: CT \n"
+ "See also https://covid-segmentation.grand-challenge.org/",
+ 'url': (
+ "https://zenodo.org/record/4635822/files/Task115_nnUNetTrainerV2_DA3__nnUNetPlans_v2.1__3d_fullres__10folds.zip?download=1",
+ "https://zenodo.org/record/4635822/files/Task115_nnUNetTrainerV2_DA3_BN__nnUNetPlans_v2.1__3d_fullres__10folds.zip?download=1",
+ "https://zenodo.org/record/4635822/files/Task115_nnUNetTrainerV2_ResencUNet__nnUNetPlans_FabiansResUNet_v2.1__3d_fullres__10folds.zip?download=1",
+ "https://zenodo.org/record/4635822/files/Task115_nnUNetTrainerV2_ResencUNet_DA3__nnUNetPlans_FabiansResUNet_v2.1__3d_fullres__10folds.zip?download=1",
+ "https://zenodo.org/record/4635822/files/Task115_nnUNetTrainerV2_ResencUNet_DA3_BN__nnUNetPlans_FabiansResUNet_v2.1__3d_lowres__10folds.zip?download=1",
+ )
+ },
+ "Task135_KiTS2021": {
+ 'description': "Kidney and kidney tumor segmentation in CT images. Data originates from KiTS2021 challenge.\n"
+ "Predicted labels are 0: background, 1: kidney, 2: tumor, 3: cyst \n"
+ "Input modalities are 0: CT \n"
+ "See also https://kits21.kits-challenge.org/",
+ 'url': (
+ "https://zenodo.org/record/5126443/files/Task135_KiTS2021.zip?download=1",
+ )
+ },
+ }
+ return available_models
+
+
+def print_available_pretrained_models():
+ print('The following pretrained models are available:\n')
+ av_models = get_available_models()
+ for m in av_models.keys():
+ print('')
+ print(m)
+ print(av_models[m]['description'])
+
+
+def download_and_install_pretrained_model_by_name(taskname):
+ av_models = get_available_models()
+ if taskname not in av_models.keys():
+ raise RuntimeError("\nThe requested pretrained model ('%s') is not available." % taskname)
+ if len(av_models[taskname]['url']) == 0:
+ raise RuntimeError("The requested model has not been uploaded yet. Please check back in a few days")
+ url = av_models[taskname]['url']
+ if isinstance(url, str):
+ download_and_install_from_url(url)
+ elif isinstance(url, (tuple, list)):
+ for u in url:
+ download_and_install_from_url(u)
+ else:
+ raise RuntimeError('URL for download_and_install_from_url must be either str or list/tuple of str')
+
+
+def download_and_install_from_url(url):
+ assert network_training_output_dir is not None, "Cannot install model because network_training_output_dir is not " \
+ "set (RESULTS_FOLDER missing as environment variable, see " \
+ "Installation instructions)"
+ print('Downloading pretrained model from url:', url)
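+ # forcing HTTP/1.0 below presumably avoids problems with chunked/keep-alive HTTP/1.1
+ # responses from some hosts (assumption; the original code does not state why this is needed)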
+ import http.client
+ http.client.HTTPConnection._http_vsn = 10
+ http.client.HTTPConnection._http_vsn_str = 'HTTP/1.0'
+
+ import os
+ home = os.path.expanduser('~')
+ random_number = int(time() * 1e7)
+ tempfile = join(home, '.nnunetdownload_%s' % str(random_number))
+
+ try:
+ download_file(url=url, local_filename=tempfile, chunk_size=8192 * 16)
+ print("Download finished. Extracting...")
+ install_model_from_zip_file(tempfile)
+ print("Done")
+ except Exception as e:
+ raise e
+ finally:
+ if isfile(tempfile):
+ os.remove(tempfile)
+
+
+def download_file(url: str, local_filename: str, chunk_size: Optional[int] = None) -> str:
+ # borrowed from https://stackoverflow.com/questions/16694907/download-large-file-in-python-with-requests
+ # NOTE the stream=True parameter below
+ with requests.get(url, stream=True) as r:
+ r.raise_for_status()
+ with open(local_filename, 'wb') as f:
+ for chunk in r.iter_content(chunk_size=chunk_size):
+ # If you have chunk encoded response uncomment if
+ # and set chunk_size parameter to None.
+ #if chunk:
+ f.write(chunk)
+ return local_filename
+
+
+def install_model_from_zip_file(zip_file: str):
+ with zipfile.ZipFile(zip_file, 'r') as zip_ref:
+ zip_ref.extractall(network_training_output_dir)
+
+
+def print_license_warning():
+ print('')
+ print('######################################################')
+ print('!!!!!!!!!!!!!!!!!!!!!!!!WARNING!!!!!!!!!!!!!!!!!!!!!!!')
+ print('######################################################')
+ print("Using the pretrained model weights is subject to the license of the dataset they were trained on. Some "
+ "allow commercial use, others don't. It is your responsibility to make sure you use them appropriately! Use "
+ "nnUNet_print_pretrained_model_info(task_name) to see a summary of the dataset and where to find its license!")
+ print('######################################################')
+ print('')
+
+
+def download_by_name():
+ import argparse
+ parser = argparse.ArgumentParser(description="Use this to download pretrained models. CAREFUL: This script will "
+ "overwrite "
+ "existing models (if they share the same trainer class and plans as "
+                                                             "the pretrained model).")
+ parser.add_argument("task_name", type=str, help='Task name of the pretrained model. To see '
+ 'available task names, run nnUNet_print_available_'
+ 'pretrained_models')
+ args = parser.parse_args()
+ taskname = args.task_name
+ print_license_warning()
+ download_and_install_pretrained_model_by_name(taskname)
+
+
+def download_by_url():
+ import argparse
+ parser = argparse.ArgumentParser(
+ description="Use this to download pretrained models. This script is intended to download models via url only. "
+ "If you want to download one of our pretrained models, please use nnUNet_download_pretrained_model. "
+ "CAREFUL: This script will overwrite "
+ "existing models (if they share the same trainer class and plans as "
+                    "the pretrained model).")
+ parser.add_argument("url", type=str, help='URL of the pretrained model')
+ args = parser.parse_args()
+ url = args.url
+ download_and_install_from_url(url)
+
+
+def install_from_zip_entry_point():
+ import argparse
+ parser = argparse.ArgumentParser(
+ description="Use this to install a zip file containing a pretrained model.")
+ parser.add_argument("zip", type=str, help='zip file')
+ args = parser.parse_args()
+    zip_file = args.zip
+    install_model_from_zip_file(zip_file)
+
+
+def print_pretrained_model_requirements():
+ import argparse
+ parser = argparse.ArgumentParser(description="Use this to see the properties of a pretrained model, especially "
+ "what input modalities it requires")
+ parser.add_argument("task_name", type=str, help='Task name of the pretrained model. To see '
+ 'available task names, run nnUNet_print_available_'
+ 'pretrained_models')
+ args = parser.parse_args()
+ taskname = args.task_name
+ av = get_available_models()
+ if taskname not in av.keys():
+ raise RuntimeError("Invalid task name. This pretrained model does not exist. To see available task names, "
+ "run nnUNet_print_available_pretrained_models")
+ print(av[taskname]['description'])
+
+
+if __name__ == '__main__':
+ url = 'https://www.dropbox.com/s/ft54q1gi060vm2x/Task004_Hippocampus.zip?dl=1'
\ No newline at end of file
diff --git a/nnunet/inference/segmentation_export.py b/nnunet/inference/segmentation_export.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d5f2732e18110d827d665cff93bd4eaaf9d6d14
--- /dev/null
+++ b/nnunet/inference/segmentation_export.py
@@ -0,0 +1,238 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import sys
+from copy import deepcopy
+from typing import Union, Tuple
+
+import numpy as np
+import SimpleITK as sitk
+from batchgenerators.augmentations.utils import resize_segmentation
+from nnunet.preprocessing.preprocessing import get_lowres_axis, get_do_separate_z, resample_data_or_seg
+from batchgenerators.utilities.file_and_folder_operations import *
+
+
+def save_segmentation_nifti_from_softmax(segmentation_softmax: Union[str, np.ndarray], out_fname: str,
+ properties_dict: dict, order: int = 1,
+ region_class_order: Tuple[Tuple[int]] = None,
+ seg_postprogess_fn: callable = None, seg_postprocess_args: tuple = None,
+ resampled_npz_fname: str = None,
+ non_postprocessed_fname: str = None, force_separate_z: bool = None,
+ interpolation_order_z: int = 0, verbose: bool = True):
+ """
+    This is a utility for writing segmentations to nifti and npz. It requires the data to have been preprocessed by
+    GenericPreprocessor because it depends on the property dictionary output (dct) to know the geometry of the
+    original data. segmentation_softmax does not have to have the same size in pixels as the original data; it will
+    be resampled to match. This is generally useful because the spacings our networks operate on are most of the
+    time not the native spacings of the image data.
+    If seg_postprogess_fn is not None then seg_postprogess_fn(segmentation, *seg_postprocess_args) will be called
+    before nifti export.
+    There is a problem with python process communication that prevents us from communicating objects larger than
+    2 GB between processes (when the length of the pickled string is communicated by the multiprocessing.Pipe
+    object, the placeholder appears not to allow for long enough strings; this could be fixed by changing 'i' to
+    'l' (for long), but that would require manually patching system python code). We circumvent that problem here
+    by saving softmax_pred to a npy file that will then be read (and finally deleted) by the worker process.
+    save_segmentation_nifti_from_softmax accepts either a filename or a np.ndarray for segmentation_softmax and
+    handles both automatically.
+ :param segmentation_softmax:
+ :param out_fname:
+ :param properties_dict:
+ :param order:
+ :param region_class_order:
+ :param seg_postprogess_fn:
+ :param seg_postprocess_args:
+ :param resampled_npz_fname:
+ :param non_postprocessed_fname:
+ :param force_separate_z: if None then we dynamically decide how to resample along z, if True/False then always
+ /never resample along z separately. Do not touch unless you know what you are doing
+ :param interpolation_order_z: if separate z resampling is done then this is the order for resampling in z
+ :param verbose:
+ :return:
+ """
+ if verbose: print("force_separate_z:", force_separate_z, "interpolation order:", order)
+
+ if isinstance(segmentation_softmax, str):
+ assert isfile(segmentation_softmax), "If isinstance(segmentation_softmax, str) then " \
+ "isfile(segmentation_softmax) must be True"
+ del_file = deepcopy(segmentation_softmax)
+ if segmentation_softmax.endswith('.npy'):
+ segmentation_softmax = np.load(segmentation_softmax)
+ elif segmentation_softmax.endswith('.npz'):
+ segmentation_softmax = np.load(segmentation_softmax)['softmax']
+ os.remove(del_file)
+
+ # first resample, then put result into bbox of cropping, then save
+ current_shape = segmentation_softmax.shape
+ shape_original_after_cropping = properties_dict.get('size_after_cropping')
+ shape_original_before_cropping = properties_dict.get('original_size_of_raw_data')
+ # current_spacing = dct.get('spacing_after_resampling')
+ # original_spacing = dct.get('original_spacing')
+
+ if np.any([i != j for i, j in zip(np.array(current_shape[1:]), np.array(shape_original_after_cropping))]):
+ if force_separate_z is None:
+ if get_do_separate_z(properties_dict.get('original_spacing')):
+ do_separate_z = True
+ lowres_axis = get_lowres_axis(properties_dict.get('original_spacing'))
+ elif get_do_separate_z(properties_dict.get('spacing_after_resampling')):
+ do_separate_z = True
+ lowres_axis = get_lowres_axis(properties_dict.get('spacing_after_resampling'))
+ else:
+ do_separate_z = False
+ lowres_axis = None
+ else:
+ do_separate_z = force_separate_z
+ if do_separate_z:
+ lowres_axis = get_lowres_axis(properties_dict.get('original_spacing'))
+ else:
+ lowres_axis = None
+
+ if lowres_axis is not None and len(lowres_axis) != 1:
+ # this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample
+ # separately in the out of plane axis
+ do_separate_z = False
+
+ if verbose: print("separate z:", do_separate_z, "lowres axis", lowres_axis)
+ seg_old_spacing = resample_data_or_seg(segmentation_softmax, shape_original_after_cropping, is_seg=False,
+ axis=lowres_axis, order=order, do_separate_z=do_separate_z,
+ order_z=interpolation_order_z)
+ # seg_old_spacing = resize_softmax_output(segmentation_softmax, shape_original_after_cropping, order=order)
+ else:
+ if verbose: print("no resampling necessary")
+ seg_old_spacing = segmentation_softmax
+
+ if resampled_npz_fname is not None:
+ np.savez_compressed(resampled_npz_fname, softmax=seg_old_spacing.astype(np.float16))
+ # this is needed for ensembling if the nonlinearity is sigmoid
+ if region_class_order is not None:
+ properties_dict['regions_class_order'] = region_class_order
+ save_pickle(properties_dict, resampled_npz_fname[:-4] + ".pkl")
+
+ if region_class_order is None:
+ seg_old_spacing = seg_old_spacing.argmax(0)
+ else:
+ seg_old_spacing_final = np.zeros(seg_old_spacing.shape[1:])
+ for i, c in enumerate(region_class_order):
+ seg_old_spacing_final[seg_old_spacing[i] > 0.5] = c
+ seg_old_spacing = seg_old_spacing_final
+
+ bbox = properties_dict.get('crop_bbox')
+
+ if bbox is not None:
+ seg_old_size = np.zeros(shape_original_before_cropping, dtype=np.uint8)
+ for c in range(3):
+ bbox[c][1] = np.min((bbox[c][0] + seg_old_spacing.shape[c], shape_original_before_cropping[c]))
+ seg_old_size[bbox[0][0]:bbox[0][1],
+ bbox[1][0]:bbox[1][1],
+ bbox[2][0]:bbox[2][1]] = seg_old_spacing
+ else:
+ seg_old_size = seg_old_spacing
+
+ if seg_postprogess_fn is not None:
+ seg_old_size_postprocessed = seg_postprogess_fn(np.copy(seg_old_size), *seg_postprocess_args)
+ else:
+ seg_old_size_postprocessed = seg_old_size
+
+ seg_resized_itk = sitk.GetImageFromArray(seg_old_size_postprocessed.astype(np.uint8))
+ seg_resized_itk.SetSpacing(properties_dict['itk_spacing'])
+ seg_resized_itk.SetOrigin(properties_dict['itk_origin'])
+ seg_resized_itk.SetDirection(properties_dict['itk_direction'])
+ sitk.WriteImage(seg_resized_itk, out_fname)
+
+ if (non_postprocessed_fname is not None) and (seg_postprogess_fn is not None):
+ seg_resized_itk = sitk.GetImageFromArray(seg_old_size.astype(np.uint8))
+ seg_resized_itk.SetSpacing(properties_dict['itk_spacing'])
+ seg_resized_itk.SetOrigin(properties_dict['itk_origin'])
+ seg_resized_itk.SetDirection(properties_dict['itk_direction'])
+ sitk.WriteImage(seg_resized_itk, non_postprocessed_fname)
+
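+# Hedged usage sketch (illustrative, not part of the original file). To stay under the
+# 2 GB pipe limit described in the docstring above, a caller may pass the path of a .npy
+# file instead of the array itself; the file is loaded and deleted by this function:
+#
+#   softmax = np.random.rand(3, 64, 64, 64).astype(np.float32)  # (classes, x, y, z)
+#   np.save('/tmp/case0.npy', softmax)
+#   save_segmentation_nifti_from_softmax('/tmp/case0.npy', '/tmp/case0.nii.gz', props)
+#
+# where props is the GenericPreprocessor property dict containing the keys used above
+# ('size_after_cropping', 'original_size_of_raw_data', 'crop_bbox', 'itk_spacing',
+# 'itk_origin', 'itk_direction').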
+
+def save_segmentation_nifti(segmentation, out_fname, dct, order=1, force_separate_z=None, order_z=0, verbose: bool = False):
+ """
+ faster and uses less ram than save_segmentation_nifti_from_softmax, but maybe less precise and also does not support
+ softmax export (which is needed for ensembling). So it's a niche function that may be useful in some cases.
+ :param segmentation:
+ :param out_fname:
+ :param dct:
+ :param order:
+ :param force_separate_z:
+ :return:
+ """
+    if verbose:
+        print("force_separate_z:", force_separate_z, "interpolation order:", order)
+    else:
+        # suppress all further prints from this function
+        sys.stdout = open(os.devnull, 'w')
+
+ if isinstance(segmentation, str):
+        assert isfile(segmentation), "If isinstance(segmentation, str) then " \
+                                     "isfile(segmentation) must be True"
+ del_file = deepcopy(segmentation)
+ segmentation = np.load(segmentation)
+ os.remove(del_file)
+
+ # first resample, then put result into bbox of cropping, then save
+ current_shape = segmentation.shape
+ shape_original_after_cropping = dct.get('size_after_cropping')
+ shape_original_before_cropping = dct.get('original_size_of_raw_data')
+ # current_spacing = dct.get('spacing_after_resampling')
+ # original_spacing = dct.get('original_spacing')
+
+ if np.any(np.array(current_shape) != np.array(shape_original_after_cropping)):
+ if order == 0:
+ seg_old_spacing = resize_segmentation(segmentation, shape_original_after_cropping, 0)
+ else:
+ if force_separate_z is None:
+ if get_do_separate_z(dct.get('original_spacing')):
+ do_separate_z = True
+ lowres_axis = get_lowres_axis(dct.get('original_spacing'))
+ elif get_do_separate_z(dct.get('spacing_after_resampling')):
+ do_separate_z = True
+ lowres_axis = get_lowres_axis(dct.get('spacing_after_resampling'))
+ else:
+ do_separate_z = False
+ lowres_axis = None
+ else:
+ do_separate_z = force_separate_z
+ if do_separate_z:
+ lowres_axis = get_lowres_axis(dct.get('original_spacing'))
+ else:
+ lowres_axis = None
+
+ print("separate z:", do_separate_z, "lowres axis", lowres_axis)
+ seg_old_spacing = resample_data_or_seg(segmentation[None], shape_original_after_cropping, is_seg=True,
+ axis=lowres_axis, order=order, do_separate_z=do_separate_z,
+ order_z=order_z)[0]
+ else:
+ seg_old_spacing = segmentation
+
+ bbox = dct.get('crop_bbox')
+
+ if bbox is not None:
+ seg_old_size = np.zeros(shape_original_before_cropping)
+ for c in range(3):
+ bbox[c][1] = np.min((bbox[c][0] + seg_old_spacing.shape[c], shape_original_before_cropping[c]))
+ seg_old_size[bbox[0][0]:bbox[0][1],
+ bbox[1][0]:bbox[1][1],
+ bbox[2][0]:bbox[2][1]] = seg_old_spacing
+ else:
+ seg_old_size = seg_old_spacing
+
+ seg_resized_itk = sitk.GetImageFromArray(seg_old_size.astype(np.uint8))
+ seg_resized_itk.SetSpacing(dct['itk_spacing'])
+ seg_resized_itk.SetOrigin(dct['itk_origin'])
+ seg_resized_itk.SetDirection(dct['itk_direction'])
+ sitk.WriteImage(seg_resized_itk, out_fname)
+
+ if not verbose:
+ sys.stdout = sys.__stdout__
diff --git a/nnunet/network_architecture/__init__.py b/nnunet/network_architecture/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b8078b9dddddf22182fec2555d8d118ea72622
--- /dev/null
+++ b/nnunet/network_architecture/__init__.py
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *
\ No newline at end of file
diff --git a/nnunet/network_architecture/custom_modules/__init__.py b/nnunet/network_architecture/custom_modules/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/network_architecture/custom_modules/conv_blocks.py b/nnunet/network_architecture/custom_modules/conv_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..b674f101e3becaf2770cfb26c4f52f82176b00c3
--- /dev/null
+++ b/nnunet/network_architecture/custom_modules/conv_blocks.py
@@ -0,0 +1,228 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from copy import deepcopy
+from nnunet.network_architecture.custom_modules.helperModules import Identity
+from torch import nn
+
+
+class ConvDropoutNormReLU(nn.Module):
+ def __init__(self, input_channels, output_channels, kernel_size, network_props):
+ """
+ if network_props['dropout_op'] is None then no dropout
+ if network_props['norm_op'] is None then no norm
+ :param input_channels:
+ :param output_channels:
+ :param kernel_size:
+ :param network_props:
+ """
+ super(ConvDropoutNormReLU, self).__init__()
+
+ network_props = deepcopy(network_props) # network_props is a dict and mutable, so we deepcopy to be safe.
+
+ self.conv = network_props['conv_op'](input_channels, output_channels, kernel_size,
+ padding=[(i - 1) // 2 for i in kernel_size],
+ **network_props['conv_op_kwargs'])
+
+ # maybe dropout
+ if network_props['dropout_op'] is not None:
+ self.do = network_props['dropout_op'](**network_props['dropout_op_kwargs'])
+ else:
+ self.do = Identity()
+
+ if network_props['norm_op'] is not None:
+ self.norm = network_props['norm_op'](output_channels, **network_props['norm_op_kwargs'])
+ else:
+ self.norm = Identity()
+
+ self.nonlin = network_props['nonlin'](**network_props['nonlin_kwargs'])
+
+ self.all = nn.Sequential(self.conv, self.do, self.norm, self.nonlin)
+
+ def forward(self, x):
+ return self.all(x)
+
+
+class StackedConvLayers(nn.Module):
+ def __init__(self, input_channels, output_channels, kernel_size, network_props, num_convs, first_stride=None):
+ """
+ if network_props['dropout_op'] is None then no dropout
+ if network_props['norm_op'] is None then no norm
+ :param input_channels:
+ :param output_channels:
+ :param kernel_size:
+        :param network_props:
+        :param num_convs:
+        :param first_stride:
+ """
+ super(StackedConvLayers, self).__init__()
+
+ network_props = deepcopy(network_props) # network_props is a dict and mutable, so we deepcopy to be safe.
+ network_props_first = deepcopy(network_props)
+
+ if first_stride is not None:
+ network_props_first['conv_op_kwargs']['stride'] = first_stride
+
+ self.convs = nn.Sequential(
+ ConvDropoutNormReLU(input_channels, output_channels, kernel_size, network_props_first),
+ *[ConvDropoutNormReLU(output_channels, output_channels, kernel_size, network_props) for _ in
+ range(num_convs - 1)]
+ )
+
+ def forward(self, x):
+ return self.convs(x)
+
+
+class BasicResidualBlock(nn.Module):
+ def __init__(self, in_planes, out_planes, kernel_size, props, stride=None):
+ """
+        This is the conv bn nonlin conv bn nonlin kind of block
+        :param in_planes:
+        :param out_planes:
+        :param kernel_size:
+        :param props:
+        :param stride:
+        """
+ super().__init__()
+
+ self.kernel_size = kernel_size
+ props['conv_op_kwargs']['stride'] = 1
+
+ self.stride = stride
+ self.props = props
+ self.out_planes = out_planes
+ self.in_planes = in_planes
+
+ if stride is not None:
+ kwargs_conv1 = deepcopy(props['conv_op_kwargs'])
+ kwargs_conv1['stride'] = stride
+ else:
+ kwargs_conv1 = props['conv_op_kwargs']
+
+ self.conv1 = props['conv_op'](in_planes, out_planes, kernel_size, padding=[(i - 1) // 2 for i in kernel_size],
+ **kwargs_conv1)
+ self.norm1 = props['norm_op'](out_planes, **props['norm_op_kwargs'])
+ self.nonlin1 = props['nonlin'](**props['nonlin_kwargs'])
+
+ if props['dropout_op_kwargs']['p'] != 0:
+ self.dropout = props['dropout_op'](**props['dropout_op_kwargs'])
+ else:
+ self.dropout = Identity()
+
+ self.conv2 = props['conv_op'](out_planes, out_planes, kernel_size, padding=[(i - 1) // 2 for i in kernel_size],
+ **props['conv_op_kwargs'])
+ self.norm2 = props['norm_op'](out_planes, **props['norm_op_kwargs'])
+ self.nonlin2 = props['nonlin'](**props['nonlin_kwargs'])
+
+ if (self.stride is not None and any((i != 1 for i in self.stride))) or (in_planes != out_planes):
+ stride_here = stride if stride is not None else 1
+ self.downsample_skip = nn.Sequential(props['conv_op'](in_planes, out_planes, 1, stride_here, bias=False),
+ props['norm_op'](out_planes, **props['norm_op_kwargs']))
+ else:
+ self.downsample_skip = lambda x: x
+
+ def forward(self, x):
+ residual = x
+
+ out = self.dropout(self.conv1(x))
+ out = self.nonlin1(self.norm1(out))
+
+ out = self.norm2(self.conv2(out))
+
+ residual = self.downsample_skip(residual)
+
+ out += residual
+
+ return self.nonlin2(out)
+
+
+class ResidualBottleneckBlock(nn.Module):
+ def __init__(self, in_planes, out_planes, kernel_size, props, stride=None):
+ """
+        This is the 1x1 -> kxk -> 1x1 bottleneck variant of the residual block (conv bn nonlin three times)
+        :param in_planes:
+        :param out_planes:
+        :param kernel_size:
+        :param props:
+        :param stride:
+        """
+ super().__init__()
+
+        if props['dropout_op_kwargs'] is not None and props['dropout_op_kwargs']['p'] > 0:
+ raise NotImplementedError("ResidualBottleneckBlock does not yet support dropout!")
+
+ self.kernel_size = kernel_size
+ props['conv_op_kwargs']['stride'] = 1
+
+ self.stride = stride
+ self.props = props
+ self.out_planes = out_planes
+ self.in_planes = in_planes
+ self.bottleneck_planes = out_planes // 4
+
+ if stride is not None:
+ kwargs_conv1 = deepcopy(props['conv_op_kwargs'])
+ kwargs_conv1['stride'] = stride
+ else:
+ kwargs_conv1 = props['conv_op_kwargs']
+
+ self.conv1 = props['conv_op'](in_planes, self.bottleneck_planes, [1 for _ in kernel_size], padding=[0 for i in kernel_size],
+ **kwargs_conv1)
+ self.norm1 = props['norm_op'](self.bottleneck_planes, **props['norm_op_kwargs'])
+ self.nonlin1 = props['nonlin'](**props['nonlin_kwargs'])
+
+ self.conv2 = props['conv_op'](self.bottleneck_planes, self.bottleneck_planes, kernel_size, padding=[(i - 1) // 2 for i in kernel_size],
+ **props['conv_op_kwargs'])
+ self.norm2 = props['norm_op'](self.bottleneck_planes, **props['norm_op_kwargs'])
+ self.nonlin2 = props['nonlin'](**props['nonlin_kwargs'])
+
+ self.conv3 = props['conv_op'](self.bottleneck_planes, out_planes, [1 for _ in kernel_size], padding=[0 for i in kernel_size],
+ **props['conv_op_kwargs'])
+ self.norm3 = props['norm_op'](out_planes, **props['norm_op_kwargs'])
+ self.nonlin3 = props['nonlin'](**props['nonlin_kwargs'])
+
+ if (self.stride is not None and any((i != 1 for i in self.stride))) or (in_planes != out_planes):
+ stride_here = stride if stride is not None else 1
+ self.downsample_skip = nn.Sequential(props['conv_op'](in_planes, out_planes, 1, stride_here, bias=False),
+ props['norm_op'](out_planes, **props['norm_op_kwargs']))
+ else:
+ self.downsample_skip = lambda x: x
+
+ def forward(self, x):
+ residual = x
+
+ out = self.nonlin1(self.norm1(self.conv1(x)))
+ out = self.nonlin2(self.norm2(self.conv2(out)))
+
+ out = self.norm3(self.conv3(out))
+
+ residual = self.downsample_skip(residual)
+
+ out += residual
+
+ return self.nonlin3(out)
+
+
+class ResidualLayer(nn.Module):
+ def __init__(self, input_channels, output_channels, kernel_size, network_props, num_blocks, first_stride=None, block=BasicResidualBlock):
+ super().__init__()
+
+ network_props = deepcopy(network_props) # network_props is a dict and mutable, so we deepcopy to be safe.
+
+ self.convs = nn.Sequential(
+ block(input_channels, output_channels, kernel_size, network_props, first_stride),
+ *[block(output_channels, output_channels, kernel_size, network_props) for _ in
+ range(num_blocks - 1)]
+ )
+
+ def forward(self, x):
+ return self.convs(x)
+
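+# Hedged construction sketch (illustrative, not part of the original file). The 'props'
+# dict must follow the layout these blocks read from; get_default_network_config in
+# nnunet/network_architecture/generic_modular_UNet.py (added in this same diff) builds a
+# compatible one:
+#
+#   props = get_default_network_config(dim=3, dropout_p=None, norm_type="bn")
+#   layer = ResidualLayer(input_channels=32, output_channels=64, kernel_size=(3, 3, 3),
+#                         network_props=props, num_blocks=2, first_stride=(2, 2, 2))
+#   out = layer(torch.rand(2, 32, 32, 64, 64))  # -> shape (2, 64, 16, 32, 32)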
diff --git a/nnunet/network_architecture/custom_modules/feature_response_normalization.py b/nnunet/network_architecture/custom_modules/feature_response_normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..558f9e6c9810b7ecdfbe3a776c6a0ff2192ed1f9
--- /dev/null
+++ b/nnunet/network_architecture/custom_modules/feature_response_normalization.py
@@ -0,0 +1,43 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.utilities.tensor_utilities import mean_tensor
+from torch import nn
+import torch
+from torch.nn.parameter import Parameter
+import torch.jit
+
+
+class FRN3D(nn.Module):
+ def __init__(self, num_features: int, eps=1e-6, **kwargs):
+ super().__init__()
+ self.eps = eps
+ self.num_features = num_features
+ self.weight = Parameter(torch.ones(1, num_features, 1, 1, 1), True)
+ self.bias = Parameter(torch.zeros(1, num_features, 1, 1, 1), True)
+ self.tau = Parameter(torch.zeros(1, num_features, 1, 1, 1), True)
+
+ def forward(self, x: torch.Tensor):
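+        # Filter Response Normalization (Singh & Krishnan, 2019): divide by the root mean
+        # squared activation over the spatial dims (no mean subtraction), then apply the
+        # learned affine and the thresholded linear unit max(., tau) in place of a ReLU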
+ x = x * torch.rsqrt(mean_tensor(x * x, [2, 3, 4], keepdim=True) + self.eps)
+
+ return torch.max(self.weight * x + self.bias, self.tau)
+
+
+if __name__ == "__main__":
+ tmp = torch.rand((3, 32, 16, 16, 16))
+
+ frn = FRN3D(32)
+
+ out = frn(tmp)
diff --git a/nnunet/network_architecture/custom_modules/helperModules.py b/nnunet/network_architecture/custom_modules/helperModules.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b6fab7f7608e6cd8ea0a89b0b94d14a9b62e88e
--- /dev/null
+++ b/nnunet/network_architecture/custom_modules/helperModules.py
@@ -0,0 +1,29 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from torch import nn
+
+
+class Identity(nn.Module):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+
+ def forward(self, input):
+ return input
+
+
+class MyGroupNorm(nn.GroupNorm):
+ def __init__(self, num_channels, eps=1e-5, affine=True, num_groups=8):
+ super(MyGroupNorm, self).__init__(num_groups, num_channels, eps, affine)
diff --git a/nnunet/network_architecture/custom_modules/mish.py b/nnunet/network_architecture/custom_modules/mish.py
new file mode 100644
index 0000000000000000000000000000000000000000..73a723eb83482356720275785e7b1ee450b8476c
--- /dev/null
+++ b/nnunet/network_architecture/custom_modules/mish.py
@@ -0,0 +1,23 @@
+############
+# https://github.com/lessw2020/mish/blob/master/mish.py
+# This code was taken from the repo above and was not created by me (Fabian)! Full credit goes to the original authors
+############
+
+import torch
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# Mish - "Mish: A Self Regularized Non-Monotonic Neural Activation Function"
+# https://arxiv.org/abs/1908.08681v1
+# implemented for PyTorch / FastAI by lessw2020
+# github: https://github.com/lessw2020/mish
+
+class Mish(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, x):
+ # inlining this saves 1 second per epoch (V100 GPU) vs having a temp x and then returning x(!)
+ return x * (torch.tanh(F.softplus(x)))
diff --git a/nnunet/network_architecture/generic_UNet.py b/nnunet/network_architecture/generic_UNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..95d351a6dbe6dbb8bdde77f5e4fa550bec02cc85
--- /dev/null
+++ b/nnunet/network_architecture/generic_UNet.py
@@ -0,0 +1,449 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from copy import deepcopy
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+import torch
+import numpy as np
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+import torch.nn.functional
+
+
+class ConvDropoutNormNonlin(nn.Module):
+ """
+    fixes a bug in an earlier version of ConvDropoutNormNonlin where lrelu was used regardless of nonlin
+ """
+
+ def __init__(self, input_channels, output_channels,
+ conv_op=nn.Conv2d, conv_kwargs=None,
+ norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
+ dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
+ nonlin=nn.LeakyReLU, nonlin_kwargs=None):
+ super(ConvDropoutNormNonlin, self).__init__()
+ if nonlin_kwargs is None:
+ nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+ if dropout_op_kwargs is None:
+ dropout_op_kwargs = {'p': 0.5, 'inplace': True}
+ if norm_op_kwargs is None:
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
+ if conv_kwargs is None:
+ conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}
+
+ self.nonlin_kwargs = nonlin_kwargs
+ self.nonlin = nonlin
+ self.dropout_op = dropout_op
+ self.dropout_op_kwargs = dropout_op_kwargs
+ self.norm_op_kwargs = norm_op_kwargs
+ self.conv_kwargs = conv_kwargs
+ self.conv_op = conv_op
+ self.norm_op = norm_op
+
+ self.conv = self.conv_op(input_channels, output_channels, **self.conv_kwargs)
+ if self.dropout_op is not None and self.dropout_op_kwargs['p'] is not None and self.dropout_op_kwargs[
+ 'p'] > 0:
+ self.dropout = self.dropout_op(**self.dropout_op_kwargs)
+ else:
+ self.dropout = None
+ self.instnorm = self.norm_op(output_channels, **self.norm_op_kwargs)
+ self.lrelu = self.nonlin(**self.nonlin_kwargs)
+
+ def forward(self, x):
+ x = self.conv(x)
+ if self.dropout is not None:
+ x = self.dropout(x)
+ return self.lrelu(self.instnorm(x))
+
+
+class ConvDropoutNonlinNorm(ConvDropoutNormNonlin):
+ def forward(self, x):
+ x = self.conv(x)
+ if self.dropout is not None:
+ x = self.dropout(x)
+ return self.instnorm(self.lrelu(x))
+
+
+class StackedConvLayers(nn.Module):
+ def __init__(self, input_feature_channels, output_feature_channels, num_convs,
+ conv_op=nn.Conv2d, conv_kwargs=None,
+ norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
+ dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
+ nonlin=nn.LeakyReLU, nonlin_kwargs=None, first_stride=None, basic_block=ConvDropoutNormNonlin):
+ '''
+        stacks ConvDropoutNormNonlin layers. first_stride will only be applied to the first layer in the stack;
+        the other parameters affect all layers
+        :param input_feature_channels:
+        :param output_feature_channels:
+        :param num_convs:
+        :param conv_op:
+        :param conv_kwargs:
+        :param norm_op:
+        :param norm_op_kwargs:
+        :param dropout_op:
+        :param dropout_op_kwargs:
+        :param nonlin:
+        :param nonlin_kwargs:
+        :param first_stride:
+        :param basic_block:
+ '''
+ self.input_channels = input_feature_channels
+ self.output_channels = output_feature_channels
+
+ if nonlin_kwargs is None:
+ nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+ if dropout_op_kwargs is None:
+ dropout_op_kwargs = {'p': 0.5, 'inplace': True}
+ if norm_op_kwargs is None:
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
+ if conv_kwargs is None:
+ conv_kwargs = {'kernel_size': 3, 'stride': 1, 'padding': 1, 'dilation': 1, 'bias': True}
+
+ self.nonlin_kwargs = nonlin_kwargs
+ self.nonlin = nonlin
+ self.dropout_op = dropout_op
+ self.dropout_op_kwargs = dropout_op_kwargs
+ self.norm_op_kwargs = norm_op_kwargs
+ self.conv_kwargs = conv_kwargs
+ self.conv_op = conv_op
+ self.norm_op = norm_op
+
+ if first_stride is not None:
+ self.conv_kwargs_first_conv = deepcopy(conv_kwargs)
+ self.conv_kwargs_first_conv['stride'] = first_stride
+ else:
+ self.conv_kwargs_first_conv = conv_kwargs
+
+ super(StackedConvLayers, self).__init__()
+ self.blocks = nn.Sequential(
+ *([basic_block(input_feature_channels, output_feature_channels, self.conv_op,
+ self.conv_kwargs_first_conv,
+ self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
+ self.nonlin, self.nonlin_kwargs)] +
+ [basic_block(output_feature_channels, output_feature_channels, self.conv_op,
+ self.conv_kwargs,
+ self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
+ self.nonlin, self.nonlin_kwargs) for _ in range(num_convs - 1)]))
+
+ def forward(self, x):
+ return self.blocks(x)
+
+
+def print_module_training_status(module):
+    if isinstance(module, (nn.Conv2d, nn.Conv3d, nn.Dropout, nn.Dropout2d, nn.Dropout3d,
+                           nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
+                           nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
+ print(str(module), module.training)
+
+
+class Upsample(nn.Module):
+ def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=False):
+ super(Upsample, self).__init__()
+ self.align_corners = align_corners
+ self.mode = mode
+ self.scale_factor = scale_factor
+ self.size = size
+
+ def forward(self, x):
+ return nn.functional.interpolate(x, size=self.size, scale_factor=self.scale_factor, mode=self.mode,
+ align_corners=self.align_corners)
+
+
+class Generic_UNet(SegmentationNetwork):
+ DEFAULT_BATCH_SIZE_3D = 2
+ DEFAULT_PATCH_SIZE_3D = (64, 192, 160)
+ SPACING_FACTOR_BETWEEN_STAGES = 2
+ BASE_NUM_FEATURES_3D = 30
+ MAX_NUMPOOL_3D = 999
+ MAX_NUM_FILTERS_3D = 320
+
+ DEFAULT_PATCH_SIZE_2D = (256, 256)
+ BASE_NUM_FEATURES_2D = 30
+ DEFAULT_BATCH_SIZE_2D = 50
+ MAX_NUMPOOL_2D = 999
+ MAX_FILTERS_2D = 480
+
+ use_this_for_batch_size_computation_2D = 19739648
+ use_this_for_batch_size_computation_3D = 520000000 # 505789440
+
+ def __init__(self, input_channels, base_num_features, num_classes, num_pool, num_conv_per_stage=2,
+ feat_map_mul_on_downscale=2, conv_op=nn.Conv2d,
+ norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
+ dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
+ nonlin=nn.LeakyReLU, nonlin_kwargs=None, deep_supervision=True, dropout_in_localization=False,
+ final_nonlin=softmax_helper, weightInitializer=InitWeights_He(1e-2), pool_op_kernel_sizes=None,
+ conv_kernel_sizes=None,
+ upscale_logits=False, convolutional_pooling=False, convolutional_upsampling=False,
+ max_num_features=None, basic_block=ConvDropoutNormNonlin,
+ seg_output_use_bias=False):
+ """
+ basically more flexible than v1, architecture is the same
+
+ Does this look complicated? Nah bro. Functionality > usability
+
+ This does everything you need, including world peace.
+
+ Questions? -> f.isensee@dkfz.de
+ """
+ super(Generic_UNet, self).__init__()
+ self.convolutional_upsampling = convolutional_upsampling
+ self.convolutional_pooling = convolutional_pooling
+ self.upscale_logits = upscale_logits
+ if nonlin_kwargs is None:
+ nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+ if dropout_op_kwargs is None:
+ dropout_op_kwargs = {'p': 0.5, 'inplace': True}
+ if norm_op_kwargs is None:
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'momentum': 0.1}
+
+ self.conv_kwargs = {'stride': 1, 'dilation': 1, 'bias': True}
+
+ self.nonlin = nonlin
+ self.nonlin_kwargs = nonlin_kwargs
+ self.dropout_op_kwargs = dropout_op_kwargs
+ self.norm_op_kwargs = norm_op_kwargs
+ self.weightInitializer = weightInitializer
+ self.conv_op = conv_op
+ self.norm_op = norm_op
+ self.dropout_op = dropout_op
+ self.num_classes = num_classes
+ self.final_nonlin = final_nonlin
+ self._deep_supervision = deep_supervision
+ self.do_ds = deep_supervision
+
+ if conv_op == nn.Conv2d:
+ upsample_mode = 'bilinear'
+ pool_op = nn.MaxPool2d
+ transpconv = nn.ConvTranspose2d
+ if pool_op_kernel_sizes is None:
+ pool_op_kernel_sizes = [(2, 2)] * num_pool
+ if conv_kernel_sizes is None:
+ conv_kernel_sizes = [(3, 3)] * (num_pool + 1)
+ elif conv_op == nn.Conv3d:
+ upsample_mode = 'trilinear'
+ pool_op = nn.MaxPool3d
+ transpconv = nn.ConvTranspose3d
+ if pool_op_kernel_sizes is None:
+ pool_op_kernel_sizes = [(2, 2, 2)] * num_pool
+ if conv_kernel_sizes is None:
+ conv_kernel_sizes = [(3, 3, 3)] * (num_pool + 1)
+ else:
+ raise ValueError("unknown convolution dimensionality, conv op: %s" % str(conv_op))
+
+ self.input_shape_must_be_divisible_by = np.prod(pool_op_kernel_sizes, 0, dtype=np.int64)
+ self.pool_op_kernel_sizes = pool_op_kernel_sizes
+ self.conv_kernel_sizes = conv_kernel_sizes
+
+ self.conv_pad_sizes = []
+ for krnl in self.conv_kernel_sizes:
+ self.conv_pad_sizes.append([1 if i == 3 else 0 for i in krnl])
+
+ if max_num_features is None:
+ if self.conv_op == nn.Conv3d:
+ self.max_num_features = self.MAX_NUM_FILTERS_3D
+ else:
+ self.max_num_features = self.MAX_FILTERS_2D
+ else:
+ self.max_num_features = max_num_features
+
+ self.conv_blocks_context = []
+ self.conv_blocks_localization = []
+ self.td = []
+ self.tu = []
+ self.seg_outputs = []
+
+ output_features = base_num_features
+ input_features = input_channels
+
+ for d in range(num_pool):
+ # determine the first stride
+ if d != 0 and self.convolutional_pooling:
+ first_stride = pool_op_kernel_sizes[d - 1]
+ else:
+ first_stride = None
+
+ self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[d]
+ self.conv_kwargs['padding'] = self.conv_pad_sizes[d]
+ # add convolutions
+ self.conv_blocks_context.append(StackedConvLayers(input_features, output_features, num_conv_per_stage,
+ self.conv_op, self.conv_kwargs, self.norm_op,
+ self.norm_op_kwargs, self.dropout_op,
+ self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs,
+ first_stride, basic_block=basic_block))
+ if not self.convolutional_pooling:
+ self.td.append(pool_op(pool_op_kernel_sizes[d]))
+ input_features = output_features
+ output_features = int(np.round(output_features * feat_map_mul_on_downscale))
+
+ output_features = min(output_features, self.max_num_features)
+
+ # now the bottleneck.
+ # determine the first stride
+ if self.convolutional_pooling:
+ first_stride = pool_op_kernel_sizes[-1]
+ else:
+ first_stride = None
+
+ # the output of the last conv must match the number of features from the skip connection if we are not using
+ # convolutional upsampling. If we use convolutional upsampling then the reduction in feature maps will be
+ # done by the transposed conv
+ if self.convolutional_upsampling:
+ final_num_features = output_features
+ else:
+ final_num_features = self.conv_blocks_context[-1].output_channels
+
+ self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[num_pool]
+ self.conv_kwargs['padding'] = self.conv_pad_sizes[num_pool]
+ self.conv_blocks_context.append(nn.Sequential(
+ StackedConvLayers(input_features, output_features, num_conv_per_stage - 1, self.conv_op, self.conv_kwargs,
+ self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin,
+ self.nonlin_kwargs, first_stride, basic_block=basic_block),
+ StackedConvLayers(output_features, final_num_features, 1, self.conv_op, self.conv_kwargs,
+ self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs, self.nonlin,
+ self.nonlin_kwargs, basic_block=basic_block)))
+
+ # if we don't want to do dropout in the localization pathway then we set the dropout prob to zero here
+ if not dropout_in_localization:
+ old_dropout_p = self.dropout_op_kwargs['p']
+ self.dropout_op_kwargs['p'] = 0.0
+
+ # now lets build the localization pathway
+ for u in range(num_pool):
+ nfeatures_from_down = final_num_features
+ nfeatures_from_skip = self.conv_blocks_context[
+ -(2 + u)].output_channels # self.conv_blocks_context[-1] is bottleneck, so start with -2
+ n_features_after_tu_and_concat = nfeatures_from_skip * 2
+
+ # the first conv reduces the number of features to match those of skip
+ # the following convs work on that number of features
+ # if not convolutional upsampling then the final conv reduces the num of features again
+ if u != num_pool - 1 and not self.convolutional_upsampling:
+ final_num_features = self.conv_blocks_context[-(3 + u)].output_channels
+ else:
+ final_num_features = nfeatures_from_skip
+
+ if not self.convolutional_upsampling:
+ self.tu.append(Upsample(scale_factor=pool_op_kernel_sizes[-(u + 1)], mode=upsample_mode))
+ else:
+ self.tu.append(transpconv(nfeatures_from_down, nfeatures_from_skip, pool_op_kernel_sizes[-(u + 1)],
+ pool_op_kernel_sizes[-(u + 1)], bias=False))
+
+ self.conv_kwargs['kernel_size'] = self.conv_kernel_sizes[- (u + 1)]
+ self.conv_kwargs['padding'] = self.conv_pad_sizes[- (u + 1)]
+ self.conv_blocks_localization.append(nn.Sequential(
+ StackedConvLayers(n_features_after_tu_and_concat, nfeatures_from_skip, num_conv_per_stage - 1,
+ self.conv_op, self.conv_kwargs, self.norm_op, self.norm_op_kwargs, self.dropout_op,
+ self.dropout_op_kwargs, self.nonlin, self.nonlin_kwargs, basic_block=basic_block),
+ StackedConvLayers(nfeatures_from_skip, final_num_features, 1, self.conv_op, self.conv_kwargs,
+ self.norm_op, self.norm_op_kwargs, self.dropout_op, self.dropout_op_kwargs,
+ self.nonlin, self.nonlin_kwargs, basic_block=basic_block)
+ ))
+
+ for ds in range(len(self.conv_blocks_localization)):
+ self.seg_outputs.append(conv_op(self.conv_blocks_localization[ds][-1].output_channels, num_classes,
+ 1, 1, 0, 1, 1, seg_output_use_bias))
+
+ self.upscale_logits_ops = []
+ cum_upsample = np.cumprod(np.vstack(pool_op_kernel_sizes), axis=0)[::-1]
+ for usl in range(num_pool - 1):
+ if self.upscale_logits:
+ self.upscale_logits_ops.append(Upsample(scale_factor=tuple([int(i) for i in cum_upsample[usl + 1]]),
+ mode=upsample_mode))
+ else:
+ self.upscale_logits_ops.append(lambda x: x)
+
+ if not dropout_in_localization:
+ self.dropout_op_kwargs['p'] = old_dropout_p
+
+ # register all modules properly
+ self.conv_blocks_localization = nn.ModuleList(self.conv_blocks_localization)
+ self.conv_blocks_context = nn.ModuleList(self.conv_blocks_context)
+ self.td = nn.ModuleList(self.td)
+ self.tu = nn.ModuleList(self.tu)
+ self.seg_outputs = nn.ModuleList(self.seg_outputs)
+ if self.upscale_logits:
+ self.upscale_logits_ops = nn.ModuleList(
+ self.upscale_logits_ops) # lambda x:x is not a Module so we need to distinguish here
+
+ if self.weightInitializer is not None:
+ self.apply(self.weightInitializer)
+ # self.apply(print_module_training_status)
+
+ def forward(self, x):
+ skips = []
+ seg_outputs = []
+ for d in range(len(self.conv_blocks_context) - 1):
+ x = self.conv_blocks_context[d](x)
+ skips.append(x)
+ if not self.convolutional_pooling:
+ x = self.td[d](x)
+
+ x = self.conv_blocks_context[-1](x)
+
+ for u in range(len(self.tu)):
+ x = self.tu[u](x)
+ x = torch.cat((x, skips[-(u + 1)]), dim=1)
+ x = self.conv_blocks_localization[u](x)
+ seg_outputs.append(self.final_nonlin(self.seg_outputs[u](x)))
+
+ if self._deep_supervision and self.do_ds:
+ return tuple([seg_outputs[-1]] + [i(j) for i, j in
+ zip(list(self.upscale_logits_ops)[::-1], seg_outputs[:-1][::-1])])
+ else:
+ return seg_outputs[-1]
+
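+    # Hedged construction sketch (illustrative, not part of the original file): a typical
+    # 3D configuration with instance norm and no dropout; values are examples only:
+    #
+    #   net = Generic_UNet(input_channels=1, base_num_features=30, num_classes=3, num_pool=5,
+    #                      conv_op=nn.Conv3d, norm_op=nn.InstanceNorm3d,
+    #                      dropout_op=nn.Dropout3d, dropout_op_kwargs={'p': 0, 'inplace': True},
+    #                      deep_supervision=False, convolutional_pooling=True,
+    #                      convolutional_upsampling=True)
+    #   pred = net(torch.rand(2, 1, 64, 160, 160))  # -> (2, 3, 64, 160, 160)
+    #
+    # note that spatial sizes must be divisible by input_shape_must_be_divisible_by
+    # (the product of the pooling kernels per axis, here 32)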
+ @staticmethod
+ def compute_approx_vram_consumption(patch_size, num_pool_per_axis, base_num_features, max_num_features,
+ num_modalities, num_classes, pool_op_kernel_sizes, deep_supervision=False,
+ conv_per_stage=2):
+ """
+        This only applies for the default num_conv_per_stage and convolutional_upsampling=True. It is not the real
+        vram consumption, just a constant term to which the vram consumption will be approx proportional
+        (+ offset for parameter storage)
+ :param deep_supervision:
+ :param patch_size:
+ :param num_pool_per_axis:
+ :param base_num_features:
+ :param max_num_features:
+ :param num_modalities:
+ :param num_classes:
+ :param pool_op_kernel_sizes:
+ :return:
+ """
+ if not isinstance(num_pool_per_axis, np.ndarray):
+ num_pool_per_axis = np.array(num_pool_per_axis)
+
+ npool = len(pool_op_kernel_sizes)
+
+ map_size = np.array(patch_size)
+ tmp = np.int64((conv_per_stage * 2 + 1) * np.prod(map_size, dtype=np.int64) * base_num_features +
+ num_modalities * np.prod(map_size, dtype=np.int64) +
+ num_classes * np.prod(map_size, dtype=np.int64))
+
+ num_feat = base_num_features
+
+ for p in range(npool):
+ for pi in range(len(num_pool_per_axis)):
+ map_size[pi] /= pool_op_kernel_sizes[p][pi]
+ num_feat = min(num_feat * 2, max_num_features)
+ num_blocks = (conv_per_stage * 2 + 1) if p < (npool - 1) else conv_per_stage # conv_per_stage + conv_per_stage for the convs of encode/decode and 1 for transposed conv
+ tmp += num_blocks * np.prod(map_size, dtype=np.int64) * num_feat
+ if deep_supervision and p < (npool - 2):
+ tmp += np.prod(map_size, dtype=np.int64) * num_classes
+ # print(p, map_size, num_feat, tmp)
+ return tmp
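+
+# Hedged usage sketch (illustrative, not part of the original file): the score is only
+# proportional to VRAM use, so it is meant for relative comparisons between candidate
+# plans (presumably against the use_this_for_batch_size_computation_* constants above);
+# the numbers below are made up:
+#
+#   score = Generic_UNet.compute_approx_vram_consumption(
+#       patch_size=(64, 192, 160), num_pool_per_axis=(4, 5, 5), base_num_features=30,
+#       max_num_features=320, num_modalities=1, num_classes=3,
+#       pool_op_kernel_sizes=[(1, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2)])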
diff --git a/nnunet/network_architecture/generic_UNet_DP.py b/nnunet/network_architecture/generic_UNet_DP.py
new file mode 100644
index 0000000000000000000000000000000000000000..02d6d2b574e7f48c177580873389362b090f1538
--- /dev/null
+++ b/nnunet/network_architecture/generic_UNet_DP.py
@@ -0,0 +1,124 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.loss_functions.crossentropy import RobustCrossEntropyLoss
+from nnunet.training.loss_functions.dice_loss import get_tp_fp_fn_tn
+from nnunet.utilities.nd_softmax import softmax_helper
+from nnunet.utilities.tensor_utilities import sum_tensor
+from torch import nn
+
+
+class Generic_UNet_DP(Generic_UNet):
+ def __init__(self, input_channels, base_num_features, num_classes, num_pool, num_conv_per_stage=2,
+ feat_map_mul_on_downscale=2, conv_op=nn.Conv2d,
+ norm_op=nn.BatchNorm2d, norm_op_kwargs=None,
+ dropout_op=nn.Dropout2d, dropout_op_kwargs=None,
+ nonlin=nn.LeakyReLU, nonlin_kwargs=None, deep_supervision=True, dropout_in_localization=False,
+ weightInitializer=InitWeights_He(1e-2), pool_op_kernel_sizes=None,
+ conv_kernel_sizes=None,
+ upscale_logits=False, convolutional_pooling=False, convolutional_upsampling=False,
+ max_num_features=None):
+ """
+ As opposed to the Generic_UNet, this class will compute parts of the loss function in the forward pass. This is
+ useful for GPU parallelization. The batch DICE loss, if used, must be computed over the whole batch. Therefore, in a
+ naive implementation, all softmax outputs must be copied to a single GPU which will then
+ do the loss computation all by itself. In the context of 3D Segmentation, this results in a lot of overhead AND
+ is inefficient because the DICE computation is also kinda expensive (Think 8 GPUs with a result of shape
+ 2x4x128x128x128 each.). The DICE is a global metric, but its parts can be computed locally (TP, FP, FN). Thus,
+ this implementation will compute all the parts of the loss function in the forward pass (and thus in a
+ parallelized way). The results are very small (batch_size x num_classes for TP, FN and FP, respectively; scalar for CE) and
+ copied easily. Also the final steps of the loss function (computing batch dice and average CE values) are easy
+ and very quick on the one GPU they need to run on. BAM.
+ final_nonlin is lambda x:x here!
+ """
+ super(Generic_UNet_DP, self).__init__(input_channels, base_num_features, num_classes, num_pool,
+ num_conv_per_stage,
+ feat_map_mul_on_downscale, conv_op,
+ norm_op, norm_op_kwargs,
+ dropout_op, dropout_op_kwargs,
+ nonlin, nonlin_kwargs, deep_supervision, dropout_in_localization,
+ lambda x: x, weightInitializer, pool_op_kernel_sizes,
+ conv_kernel_sizes,
+ upscale_logits, convolutional_pooling, convolutional_upsampling,
+ max_num_features)
+ self.ce_loss = RobustCrossEntropyLoss()
+
+ def forward(self, x, y=None, return_hard_tp_fp_fn=False):
+ res = super(Generic_UNet_DP, self).forward(x) # regular Generic_UNet forward pass
+
+ if y is None:
+ return res
+ else:
+ # compute ce loss
+ if self._deep_supervision and self.do_ds:
+ ce_losses = [self.ce_loss(res[0], y[0]).unsqueeze(0)]
+ tps = []
+ fps = []
+ fns = []
+
+ res_softmax = softmax_helper(res[0])
+ tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y[0])
+ tps.append(tp)
+ fps.append(fp)
+ fns.append(fn)
+ for i in range(1, len(y)):
+ ce_losses.append(self.ce_loss(res[i], y[i]).unsqueeze(0))
+ res_softmax = softmax_helper(res[i])
+ tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y[i])
+ tps.append(tp)
+ fps.append(fp)
+ fns.append(fn)
+ ret = ce_losses, tps, fps, fns
+ else:
+ ce_loss = self.ce_loss(res, y).unsqueeze(0)
+
+ # tp fp and fn need the output to be softmax
+ res_softmax = softmax_helper(res)
+
+ tp, fp, fn, _ = get_tp_fp_fn_tn(res_softmax, y)
+
+ ret = ce_loss, tp, fp, fn
+
+ if return_hard_tp_fp_fn:
+ if self._deep_supervision and self.do_ds:
+ output = res[0]
+ target = y[0]
+ else:
+ target = y
+ output = res
+
+ with torch.no_grad():
+ num_classes = output.shape[1]
+ output_softmax = softmax_helper(output)
+ output_seg = output_softmax.argmax(1)
+ target = target[:, 0]
+ axes = tuple(range(1, len(target.shape)))
+ tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
+ fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
+ fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
+ for c in range(1, num_classes):
+ tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes)
+ fp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes)
+ fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes)
+
+ tp_hard = tp_hard.sum(0, keepdim=False)[None]
+ fp_hard = fp_hard.sum(0, keepdim=False)[None]
+ fn_hard = fn_hard.sum(0, keepdim=False)[None]
+
+ ret = *ret, tp_hard, fp_hard, fn_hard
+ return ret
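+
+# Hedged trainer-side sketch (illustrative, not part of the original file; names are
+# placeholders). Under nn.DataParallel each replica returns its local CE loss and
+# TP/FP/FN counts, so only these small tensors are gathered; a batch dice can then be
+# computed cheaply on a single device (non deep supervision case):
+#
+#   ce, tp, fp, fn = model(data, target)
+#   dc = (2 * tp.sum(0)) / (2 * tp.sum(0) + fp.sum(0) + fn.sum(0) + 1e-8)
+#   loss = ce.mean() - dc.mean()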
diff --git a/nnunet/network_architecture/generic_modular_UNet.py b/nnunet/network_architecture/generic_modular_UNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4901449c3a9b1afb5ae91ac8166749bae35d2d4
--- /dev/null
+++ b/nnunet/network_architecture/generic_modular_UNet.py
@@ -0,0 +1,470 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from nnunet.network_architecture.custom_modules.conv_blocks import StackedConvLayers
+from nnunet.network_architecture.generic_UNet import Upsample
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.training.loss_functions.dice_loss import DC_and_CE_loss
+from torch import nn
+import numpy as np
+from torch.optim import SGD
+
+"""
+The idea behind this modular U-Net is that we decouple encoder and decoder and thus make things a) much easier to
+combine and b) enable easy swapping between segmentation and classification mode of the same architecture
+"""
+
+
+def get_default_network_config(dim=2, dropout_p=None, nonlin="LeakyReLU", norm_type="bn"):
+ """
+ returns a dictionary that contains pointers to conv, nonlin and norm ops and the default kwargs I like to use
+ :return:
+ """
+ props = {}
+ if dim == 2:
+ props['conv_op'] = nn.Conv2d
+ props['dropout_op'] = nn.Dropout2d
+ elif dim == 3:
+ props['conv_op'] = nn.Conv3d
+ props['dropout_op'] = nn.Dropout3d
+ else:
+ raise NotImplementedError
+
+ if norm_type == "bn":
+ if dim == 2:
+ props['norm_op'] = nn.BatchNorm2d
+ elif dim == 3:
+ props['norm_op'] = nn.BatchNorm3d
+ props['norm_op_kwargs'] = {'eps': 1e-5, 'affine': True}
+ elif norm_type == "in":
+ if dim == 2:
+ props['norm_op'] = nn.InstanceNorm2d
+ elif dim == 3:
+ props['norm_op'] = nn.InstanceNorm3d
+ props['norm_op_kwargs'] = {'eps': 1e-5, 'affine': True}
+ else:
+ raise NotImplementedError
+
+ if dropout_p is None:
+ props['dropout_op'] = None
+ props['dropout_op_kwargs'] = {'p': 0, 'inplace': True}
+ else:
+ props['dropout_op_kwargs'] = {'p': dropout_p, 'inplace': True}
+
+ props['conv_op_kwargs'] = {'stride': 1, 'dilation': 1, 'bias': True} # kernel size will be set by network!
+
+ if nonlin == "LeakyReLU":
+ props['nonlin'] = nn.LeakyReLU
+ props['nonlin_kwargs'] = {'negative_slope': 1e-2, 'inplace': True}
+ elif nonlin == "ReLU":
+ props['nonlin'] = nn.ReLU
+ props['nonlin_kwargs'] = {'inplace': True}
+ else:
+        raise ValueError("unknown nonlin: %s" % str(nonlin))
+
+ return props
+
+
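+# Hedged usage sketch (illustrative, not part of the original file): the dict built above
+# is the 'props' argument the encoder/decoder below expect. Note that kernel size is
+# deliberately absent from conv_op_kwargs and is supplied per stage by the network:
+#
+#   props = get_default_network_config(dim=3, dropout_p=None, nonlin="LeakyReLU", norm_type="in")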
+
+class PlainConvUNetEncoder(nn.Module):
+ def __init__(self, input_channels, base_num_features, num_blocks_per_stage, feat_map_mul_on_downscale,
+ pool_op_kernel_sizes, conv_kernel_sizes, props, default_return_skips=True,
+ max_num_features=480):
+ """
+        Subsequent UNet building blocks can be attached by using the properties this class exposes (TODO)
+
+ this one includes the bottleneck layer!
+
+ :param input_channels:
+ :param base_num_features:
+ :param num_blocks_per_stage:
+ :param feat_map_mul_on_downscale:
+ :param pool_op_kernel_sizes:
+ :param conv_kernel_sizes:
+ :param props:
+ """
+ super(PlainConvUNetEncoder, self).__init__()
+
+ self.default_return_skips = default_return_skips
+ self.props = props
+
+ self.stages = []
+ self.stage_output_features = []
+ self.stage_pool_kernel_size = []
+ self.stage_conv_op_kernel_size = []
+
+ assert len(pool_op_kernel_sizes) == len(conv_kernel_sizes)
+
+ num_stages = len(conv_kernel_sizes)
+
+ if not isinstance(num_blocks_per_stage, (list, tuple)):
+ num_blocks_per_stage = [num_blocks_per_stage] * num_stages
+ else:
+ assert len(num_blocks_per_stage) == num_stages
+
+ self.num_blocks_per_stage = num_blocks_per_stage # decoder may need this
+
+ current_input_features = input_channels
+ for stage in range(num_stages):
+ current_output_features = min(int(base_num_features * feat_map_mul_on_downscale ** stage), max_num_features)
+ current_kernel_size = conv_kernel_sizes[stage]
+ current_pool_kernel_size = pool_op_kernel_sizes[stage]
+
+ current_stage = StackedConvLayers(current_input_features, current_output_features, current_kernel_size,
+ props, num_blocks_per_stage[stage], current_pool_kernel_size)
+
+ self.stages.append(current_stage)
+ self.stage_output_features.append(current_output_features)
+ self.stage_conv_op_kernel_size.append(current_kernel_size)
+ self.stage_pool_kernel_size.append(current_pool_kernel_size)
+
+ # update current_input_features
+ current_input_features = current_output_features
+
+ self.stages = nn.ModuleList(self.stages)
+ self.output_features = current_output_features
+
+ def forward(self, x, return_skips=None):
+ """
+
+ :param x:
+ :param return_skips: if none then self.default_return_skips is used
+ :return:
+ """
+ skips = []
+
+ for s in self.stages:
+ x = s(x)
+ if self.default_return_skips:
+ skips.append(x)
+
+ if return_skips is None:
+ return_skips = self.default_return_skips
+
+ if return_skips:
+ return skips
+ else:
+ return x
+
+ @staticmethod
+ def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_modalities, pool_op_kernel_sizes, num_blocks_per_stage_encoder,
+ feat_map_mul_on_downscale, batch_size):
+ npool = len(pool_op_kernel_sizes) - 1
+
+ current_shape = np.array(patch_size)
+
+ tmp = num_blocks_per_stage_encoder[0] * np.prod(current_shape) * base_num_features \
+ + num_modalities * np.prod(current_shape)
+
+ num_feat = base_num_features
+
+ for p in range(1, npool + 1):
+ current_shape = current_shape / np.array(pool_op_kernel_sizes[p])
+ num_feat = min(num_feat * feat_map_mul_on_downscale, max_num_features)
+ num_convs = num_blocks_per_stage_encoder[p]
+            # print(p, num_feat, num_convs, current_shape)
+ tmp += num_convs * np.prod(current_shape) * num_feat
+ return tmp * batch_size
+
+
+class PlainConvUNetDecoder(nn.Module):
+ def __init__(self, previous, num_classes, num_blocks_per_stage=None, network_props=None, deep_supervision=False,
+ upscale_logits=False):
+ super(PlainConvUNetDecoder, self).__init__()
+ self.num_classes = num_classes
+ self.deep_supervision = deep_supervision
+ """
+ We assume the bottleneck is part of the encoder, so we can start with upsample -> concat here
+ """
+ previous_stages = previous.stages
+ previous_stage_output_features = previous.stage_output_features
+ previous_stage_pool_kernel_size = previous.stage_pool_kernel_size
+ previous_stage_conv_op_kernel_size = previous.stage_conv_op_kernel_size
+
+ if network_props is None:
+ self.props = previous.props
+ else:
+ self.props = network_props
+
+ if self.props['conv_op'] == nn.Conv2d:
+ transpconv = nn.ConvTranspose2d
+ upsample_mode = "bilinear"
+ elif self.props['conv_op'] == nn.Conv3d:
+ transpconv = nn.ConvTranspose3d
+ upsample_mode = "trilinear"
+ else:
+ raise ValueError("unknown convolution dimensionality, conv op: %s" % str(self.props['conv_op']))
+
+ if num_blocks_per_stage is None:
+ num_blocks_per_stage = previous.num_blocks_per_stage[:-1][::-1]
+
+ assert len(num_blocks_per_stage) == len(previous.num_blocks_per_stage) - 1
+
+ self.stage_pool_kernel_size = previous_stage_pool_kernel_size
+ self.stage_output_features = previous_stage_output_features
+ self.stage_conv_op_kernel_size = previous_stage_conv_op_kernel_size
+
+        # we have one stage less because the first decoder stage is the one that comes after the bottleneck
+        num_stages = len(previous_stages) - 1
+
+ self.tus = []
+ self.stages = []
+ self.deep_supervision_outputs = []
+
+        # only used for upscale_logits
+ cum_upsample = np.cumprod(np.vstack(self.stage_pool_kernel_size), axis=0).astype(int)
+
+ for i, s in enumerate(np.arange(num_stages)[::-1]):
+ features_below = previous_stage_output_features[s + 1]
+ features_skip = previous_stage_output_features[s]
+
+ self.tus.append(transpconv(features_below, features_skip, previous_stage_pool_kernel_size[s + 1],
+ previous_stage_pool_kernel_size[s + 1], bias=False))
+            # after the transposed conv we concatenate the skip, so the stage sees 2 * features_skip channels
+ self.stages.append(StackedConvLayers(2 * features_skip, features_skip,
+ previous_stage_conv_op_kernel_size[s], self.props,
+ num_blocks_per_stage[i]))
+
+ if deep_supervision and s != 0:
+ seg_layer = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, False)
+ if upscale_logits:
+ upsample = Upsample(scale_factor=cum_upsample[s], mode=upsample_mode)
+ self.deep_supervision_outputs.append(nn.Sequential(seg_layer, upsample))
+ else:
+ self.deep_supervision_outputs.append(seg_layer)
+
+ self.segmentation_output = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, False)
+
+ self.tus = nn.ModuleList(self.tus)
+ self.stages = nn.ModuleList(self.stages)
+ self.deep_supervision_outputs = nn.ModuleList(self.deep_supervision_outputs)
+
+    def forward(self, skips, gt=None, loss=None):
+        # skips come from the encoder and are sorted so that the bottleneck is last in the list.
+        # The TUs and stages here are sorted the other way around, so we simply reverse the order of skips.
+        skips = skips[::-1]
+ seg_outputs = []
+
+ x = skips[0] # this is the bottleneck
+
+ for i in range(len(self.tus)):
+ x = self.tus[i](x)
+ x = torch.cat((x, skips[i + 1]), dim=1)
+ x = self.stages[i](x)
+ if self.deep_supervision and (i != len(self.tus) - 1):
+ tmp = self.deep_supervision_outputs[i](x)
+ if gt is not None:
+ tmp = loss(tmp, gt)
+ seg_outputs.append(tmp)
+
+ segmentation = self.segmentation_output(x)
+
+ if self.deep_supervision:
+ tmp = segmentation
+ if gt is not None:
+ tmp = loss(tmp, gt)
+ seg_outputs.append(tmp)
+            # seg_outputs are ordered so that the prediction at the highest resolution is first and the one from
+            # the bottleneck of the UNet is last
+            return seg_outputs[::-1]
+ else:
+ return segmentation
+
+ @staticmethod
+ def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_classes, pool_op_kernel_sizes, num_blocks_per_stage_decoder,
+ feat_map_mul_on_downscale, batch_size):
+ """
+ This only applies for num_blocks_per_stage and convolutional_upsampling=True
+ not real vram consumption. just a constant term to which the vram consumption will be approx proportional
+ (+ offset for parameter storage)
+ :param patch_size:
+ :param num_pool_per_axis:
+ :param base_num_features:
+ :param max_num_features:
+ :return:
+ """
+ npool = len(pool_op_kernel_sizes) - 1
+
+ current_shape = np.array(patch_size)
+ tmp = (num_blocks_per_stage_decoder[-1] + 1) * np.prod(current_shape) * base_num_features + num_classes * np.prod(current_shape)
+
+ num_feat = base_num_features
+
+ for p in range(1, npool):
+ current_shape = current_shape / np.array(pool_op_kernel_sizes[p])
+ num_feat = min(num_feat * feat_map_mul_on_downscale, max_num_features)
+ num_convs = num_blocks_per_stage_decoder[-(p+1)] + 1
+ print(p, num_feat, num_convs, current_shape)
+ tmp += num_convs * np.prod(current_shape) * num_feat
+
+ return tmp * batch_size
+
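+# With deep_supervision=True the decoder returns a list of logits ordered highest resolution first;
+# if gt and loss are passed to forward, it returns the per-scale losses instead (sketch, continuing
+# the encoder example above):
+#
+#   decoder = PlainConvUNetDecoder(encoder, num_classes=4, deep_supervision=True)
+#   outputs = decoder(encoder(torch.rand((2, 4, 64, 64))))  # [full resolution logits, ..., lowest]
+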
+
+class PlainConvUNet(SegmentationNetwork):
+ use_this_for_batch_size_computation_2D = 1167982592.0
+ use_this_for_batch_size_computation_3D = 1152286720.0
+
+ def __init__(self, input_channels, base_num_features, num_blocks_per_stage_encoder, feat_map_mul_on_downscale,
+ pool_op_kernel_sizes, conv_kernel_sizes, props, num_classes, num_blocks_per_stage_decoder,
+ deep_supervision=False, upscale_logits=False, max_features=512, initializer=None):
+ super(PlainConvUNet, self).__init__()
+ self.conv_op = props['conv_op']
+ self.num_classes = num_classes
+
+ self.encoder = PlainConvUNetEncoder(input_channels, base_num_features, num_blocks_per_stage_encoder,
+ feat_map_mul_on_downscale, pool_op_kernel_sizes, conv_kernel_sizes,
+ props, default_return_skips=True, max_num_features=max_features)
+ self.decoder = PlainConvUNetDecoder(self.encoder, num_classes, num_blocks_per_stage_decoder, props,
+ deep_supervision, upscale_logits)
+ if initializer is not None:
+ self.apply(initializer)
+
+ def forward(self, x):
+ skips = self.encoder(x)
+ return self.decoder(skips)
+
+ @staticmethod
+ def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_modalities, num_classes, pool_op_kernel_sizes, num_blocks_per_stage_encoder,
+ num_blocks_per_stage_decoder, feat_map_mul_on_downscale, batch_size):
+ enc = PlainConvUNetEncoder.compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_modalities, pool_op_kernel_sizes,
+ num_blocks_per_stage_encoder,
+ feat_map_mul_on_downscale, batch_size)
+ dec = PlainConvUNetDecoder.compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_classes, pool_op_kernel_sizes,
+ num_blocks_per_stage_decoder,
+ feat_map_mul_on_downscale, batch_size)
+
+ return enc + dec
+
+ @staticmethod
+ def compute_reference_for_vram_consumption_3d():
+ patch_size = (160, 128, 128)
+ pool_op_kernel_sizes = ((1, 1, 1),
+ (2, 2, 2),
+ (2, 2, 2),
+ (2, 2, 2),
+ (2, 2, 2),
+ (2, 2, 2))
+ conv_per_stage_encoder = (2, 2, 2, 2, 2, 2)
+ conv_per_stage_decoder = (2, 2, 2, 2, 2)
+
+ return PlainConvUNet.compute_approx_vram_consumption(patch_size, 32, 512, 4, 3, pool_op_kernel_sizes,
+ conv_per_stage_encoder, conv_per_stage_decoder, 2, 2)
+
+ @staticmethod
+ def compute_reference_for_vram_consumption_2d():
+ patch_size = (256, 256)
+ pool_op_kernel_sizes = (
+ (1, 1), # (256, 256)
+ (2, 2), # (128, 128)
+ (2, 2), # (64, 64)
+ (2, 2), # (32, 32)
+ (2, 2), # (16, 16)
+ (2, 2), # (8, 8)
+ (2, 2) # (4, 4)
+ )
+ conv_per_stage_encoder = (2, 2, 2, 2, 2, 2, 2)
+ conv_per_stage_decoder = (2, 2, 2, 2, 2, 2)
+
+ return PlainConvUNet.compute_approx_vram_consumption(patch_size, 32, 512, 4, 3, pool_op_kernel_sizes,
+ conv_per_stage_encoder, conv_per_stage_decoder, 2, 56)
+
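+# The use_this_for_batch_size_computation_* constants above store this proxy's value for reference
+# configurations that are known to fit. A planner can then size the batch for a new configuration
+# roughly like this (illustrative sketch, not the actual experiment planner code):
+#
+#   ref = PlainConvUNet.use_this_for_batch_size_computation_2D
+#   per_sample = PlainConvUNet.compute_approx_vram_consumption(..., batch_size=1)
+#   batch_size = max(2, int(ref / per_sample))
+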
+
+if __name__ == "__main__":
+ conv_op_kernel_sizes = ((3, 3),
+ (3, 3),
+ (3, 3),
+ (3, 3),
+ (3, 3),
+ (3, 3),
+ (3, 3))
+ pool_op_kernel_sizes = ((1, 1),
+ (2, 2),
+ (2, 2),
+ (2, 2),
+ (2, 2),
+ (2, 2),
+ (2, 2))
+ patch_size = (256, 256)
+ batch_size = 56
+ unet = PlainConvUNet(4, 32, (2, 2, 2, 2, 2, 2, 2), 2, pool_op_kernel_sizes, conv_op_kernel_sizes,
+ get_default_network_config(2, dropout_p=None), 4, (2, 2, 2, 2, 2, 2), False, False, max_features=512).cuda()
+ optimizer = SGD(unet.parameters(), lr=0.1, momentum=0.95)
+
+ unet.compute_reference_for_vram_consumption_3d()
+ unet.compute_reference_for_vram_consumption_2d()
+
+ dummy_input = torch.rand((batch_size, 4, *patch_size)).cuda()
+ dummy_gt = (torch.rand((batch_size, 1, *patch_size)) * 4).round().clamp_(0, 3).cuda().long()
+
+ optimizer.zero_grad()
+ skips = unet.encoder(dummy_input)
+ print([i.shape for i in skips])
+ loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'smooth_in_nom': True,
+ 'do_bg': False, 'rebalance_weights': None, 'background_weight': 1}, {})
+ output = unet.decoder(skips)
+
+ l = loss(output, dummy_gt)
+ l.backward()
+
+ optimizer.step()
+
+ import hiddenlayer as hl
+ g = hl.build_graph(unet, dummy_input)
+ g.save("/home/fabian/test.pdf")
+
+ """conv_op_kernel_sizes = ((3, 3, 3),
+ (3, 3, 3),
+ (3, 3, 3),
+ (3, 3, 3),
+ (3, 3, 3),
+ (3, 3, 3))
+ pool_op_kernel_sizes = ((1, 1, 1),
+ (2, 2, 2),
+ (2, 2, 2),
+ (2, 2, 2),
+ (2, 2, 2),
+ (2, 2, 2))
+ patch_size = (160, 128, 128)
+ unet = PlainConvUNet(4, 32, (2, 2, 2, 2, 2, 2), 2, pool_op_kernel_sizes, conv_op_kernel_sizes,
+ get_default_network_config(3, dropout_p=None), 4, (2, 2, 2, 2, 2), False, False, max_features=512).cuda()
+ optimizer = SGD(unet.parameters(), lr=0.1, momentum=0.95)
+
+ unet.compute_reference_for_vram_consumption_3d()
+ unet.compute_reference_for_vram_consumption_2d()
+
+ dummy_input = torch.rand((2, 4, *patch_size)).cuda()
+ dummy_gt = (torch.rand((2, 1, *patch_size)) * 4).round().clamp_(0, 3).cuda().long()
+
+ optimizer.zero_grad()
+ skips = unet.encoder(dummy_input)
+ print([i.shape for i in skips])
+ loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'smooth_in_nom': True,
+ 'do_bg': False, 'rebalance_weights': None, 'background_weight': 1}, {})
+ output = unet.decoder(skips)
+
+ l = loss(output, dummy_gt)
+ l.backward()
+
+ optimizer.step()
+
+ import hiddenlayer as hl
+ g = hl.build_graph(unet, dummy_input)
+ g.save("/home/fabian/test.pdf")"""
diff --git a/nnunet/network_architecture/generic_modular_preact_residual_UNet.py b/nnunet/network_architecture/generic_modular_preact_residual_UNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..7846f8252465ccb25e711e2366e3856ba44ef427
--- /dev/null
+++ b/nnunet/network_architecture/generic_modular_preact_residual_UNet.py
@@ -0,0 +1,619 @@
+import numpy as np
+from copy import deepcopy
+import torch
+from torch.backends import cudnn
+from torch.cuda.amp import GradScaler, autocast
+from torch.nn import Identity
+
+from nnunet.network_architecture.generic_UNet import Upsample
+from nnunet.network_architecture.generic_modular_UNet import PlainConvUNetDecoder, get_default_network_config
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.training.loss_functions.dice_loss import DC_and_CE_loss
+from torch import nn
+from torch.optim import SGD
+
+
+class BasicPreActResidualBlock(nn.Module):
+ def __init__(self, in_planes, out_planes, kernel_size, props, stride=None):
+ """
+ This is norm nonlin conv norm nonlin conv
+ :param in_planes:
+ :param out_planes:
+ :param props:
+ :param override_stride:
+ """
+ super().__init__()
+
+ self.kernel_size = kernel_size
+        props['conv_op_kwargs']['stride'] = 1  # note: mutates props; PreActResidualLayer deepcopies props before passing it in
+
+ self.stride = stride
+ self.props = props
+ self.out_planes = out_planes
+ self.in_planes = in_planes
+
+ if stride is not None:
+ kwargs_conv1 = deepcopy(props['conv_op_kwargs'])
+ kwargs_conv1['stride'] = stride
+ else:
+ kwargs_conv1 = props['conv_op_kwargs']
+
+ self.norm1 = props['norm_op'](in_planes, **props['norm_op_kwargs'])
+ self.nonlin1 = props['nonlin'](**props['nonlin_kwargs'])
+ self.conv1 = props['conv_op'](in_planes, out_planes, kernel_size, padding=[(i - 1) // 2 for i in kernel_size],
+ **kwargs_conv1)
+
+ if props['dropout_op_kwargs']['p'] != 0:
+ self.dropout = props['dropout_op'](**props['dropout_op_kwargs'])
+ else:
+ self.dropout = Identity()
+
+ self.norm2 = props['norm_op'](out_planes, **props['norm_op_kwargs'])
+ self.nonlin2 = props['nonlin'](**props['nonlin_kwargs'])
+ self.conv2 = props['conv_op'](out_planes, out_planes, kernel_size, padding=[(i - 1) // 2 for i in kernel_size],
+ **props['conv_op_kwargs'])
+
+ if (self.stride is not None and any((i != 1 for i in self.stride))) or (in_planes != out_planes):
+ stride_here = stride if stride is not None else 1
+ self.downsample_skip = nn.Sequential(props['conv_op'](in_planes, out_planes, 1, stride_here, bias=False))
+ else:
+ self.downsample_skip = None
+
+ def forward(self, x):
+ residual = x
+
+ out = self.nonlin1(self.norm1(x))
+
+ if self.downsample_skip is not None:
+ residual = self.downsample_skip(out)
+
+ # norm nonlin conv
+ out = self.conv1(out)
+
+        out = self.dropout(out)  # this is the identity if props['dropout_op_kwargs']['p'] == 0
+
+ # norm nonlin conv
+ out = self.conv2(self.nonlin2(self.norm2(out)))
+
+ out += residual
+
+ return out
+
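+# Quick shape check for the block above (illustrative; assumes the usual props dict from
+# get_default_network_config with dropout disabled):
+#
+#   props = get_default_network_config(2, dropout_p=None)
+#   block = BasicPreActResidualBlock(32, 64, (3, 3), props, stride=(2, 2))
+#   y = block(torch.rand((2, 32, 64, 64)))  # -> torch.Size([2, 64, 32, 32])
+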
+
+class PreActResidualLayer(nn.Module):
+ def __init__(self, input_channels, output_channels, kernel_size, network_props, num_blocks, first_stride=None):
+ super().__init__()
+
+ network_props = deepcopy(network_props) # network_props is a dict and mutable, so we deepcopy to be safe.
+
+ self.convs = nn.Sequential(
+ BasicPreActResidualBlock(input_channels, output_channels, kernel_size, network_props, first_stride),
+ *[BasicPreActResidualBlock(output_channels, output_channels, kernel_size, network_props) for _ in
+ range(num_blocks - 1)]
+ )
+
+ def forward(self, x):
+ return self.convs(x)
+
+
+class PreActResidualUNetEncoder(nn.Module):
+ def __init__(self, input_channels, base_num_features, num_blocks_per_stage, feat_map_mul_on_downscale,
+ pool_op_kernel_sizes, conv_kernel_sizes, props, default_return_skips=True,
+ max_num_features=480, pool_type: str = 'conv'):
+ """
+        Further UNet building blocks can be added by utilizing the properties this class exposes (TODO)
+
+        this one includes the bottleneck layer!
+
+        :param input_channels: number of channels/modalities of the network input
+        :param base_num_features: number of feature maps in the first (highest resolution) stage
+        :param num_blocks_per_stage: int or list/tuple with one entry per stage
+        :param feat_map_mul_on_downscale: factor by which the number of feature maps grows per downsampling
+        :param pool_op_kernel_sizes: pooling/stride kernel size for each stage
+        :param conv_kernel_sizes: convolution kernel size for each stage
+        :param props: dict of network properties (conv_op, norm_op, nonlin, dropout_op and their kwargs)
+ """
+ super(PreActResidualUNetEncoder, self).__init__()
+
+ self.default_return_skips = default_return_skips
+ self.props = props
+
+ pool_op = self._handle_pool(pool_type)
+
+ self.stages = []
+ self.stage_output_features = []
+ self.stage_pool_kernel_size = []
+ self.stage_conv_op_kernel_size = []
+
+ assert len(pool_op_kernel_sizes) == len(conv_kernel_sizes)
+
+ num_stages = len(conv_kernel_sizes)
+
+ if not isinstance(num_blocks_per_stage, (list, tuple)):
+ num_blocks_per_stage = [num_blocks_per_stage] * num_stages
+ else:
+ assert len(num_blocks_per_stage) == num_stages
+
+ self.num_blocks_per_stage = num_blocks_per_stage # decoder may need this
+
+ self.initial_conv = props['conv_op'](input_channels, base_num_features, 3, padding=1, **props['conv_op_kwargs'])
+
+ current_input_features = base_num_features
+ for stage in range(num_stages):
+ current_output_features = min(base_num_features * feat_map_mul_on_downscale ** stage, max_num_features)
+ current_kernel_size = conv_kernel_sizes[stage]
+
+            current_pool_kernel_size = pool_op_kernel_sizes[stage]
+            if pool_op is not None:
+                # downsampling is done by a dedicated pooling op, so the first conv of the stage uses stride 1
+                pool_kernel_size_for_conv = [1 for _ in current_pool_kernel_size]
+            else:
+                # downsampling is done by a strided conv in the first block of the stage
+                pool_kernel_size_for_conv = current_pool_kernel_size
+
+ current_stage = PreActResidualLayer(current_input_features, current_output_features, current_kernel_size, props,
+ self.num_blocks_per_stage[stage], pool_kernel_size_for_conv)
+ if pool_op is not None:
+ current_stage = nn.Sequential(pool_op(current_pool_kernel_size), current_stage)
+
+ self.stages.append(current_stage)
+ self.stage_output_features.append(current_output_features)
+ self.stage_conv_op_kernel_size.append(current_kernel_size)
+ self.stage_pool_kernel_size.append(current_pool_kernel_size)
+
+ # update current_input_features
+ current_input_features = current_output_features
+
+ self.stages = nn.ModuleList(self.stages)
+ self.output_features = current_input_features
+
+ def _handle_pool(self, pool_type):
+ assert pool_type in ['conv', 'avg', 'max']
+ if pool_type == 'avg':
+ if self.props['conv_op'] == nn.Conv2d:
+ pool_op = nn.AvgPool2d
+ elif self.props['conv_op'] == nn.Conv3d:
+ pool_op = nn.AvgPool3d
+ else:
+ raise NotImplementedError
+ elif pool_type == 'max':
+ if self.props['conv_op'] == nn.Conv2d:
+ pool_op = nn.MaxPool2d
+ elif self.props['conv_op'] == nn.Conv3d:
+ pool_op = nn.MaxPool3d
+ else:
+ raise NotImplementedError
+ elif pool_type == 'conv':
+ pool_op = None
+ else:
+ raise ValueError
+ return pool_op
+
+    def forward(self, x, return_skips=None):
+        """
+
+        :param x: input tensor
+        :param return_skips: if None then self.default_return_skips is used
+        :return: list of skips (bottleneck last) if return_skips, else only the final stage output
+        """
+        if return_skips is None:
+            return_skips = self.default_return_skips
+
+        skips = []
+
+        x = self.initial_conv(x)
+
+        for s in self.stages:
+            x = s(x)
+            if return_skips:
+                skips.append(x)
+
+        if return_skips:
+            return skips
+        else:
+            return x
+
+ @staticmethod
+ def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_modalities, pool_op_kernel_sizes, num_conv_per_stage_encoder,
+ feat_map_mul_on_downscale, batch_size):
+ npool = len(pool_op_kernel_sizes) - 1
+
+ current_shape = np.array(patch_size)
+
+ tmp = (num_conv_per_stage_encoder[0] * 2 + 1) * np.prod(current_shape) * base_num_features \
+ + num_modalities * np.prod(current_shape)
+
+ num_feat = base_num_features
+
+ for p in range(1, npool + 1):
+ current_shape = current_shape / np.array(pool_op_kernel_sizes[p])
+ num_feat = min(num_feat * feat_map_mul_on_downscale, max_num_features)
+ num_convs = num_conv_per_stage_encoder[p] * 2 + 1 # + 1 for conv in skip in first block
+ print(p, num_feat, num_convs, current_shape)
+ tmp += num_convs * np.prod(current_shape) * num_feat
+ return tmp * batch_size
+
+
+class PreActResidualUNetDecoder(nn.Module):
+ def __init__(self, previous, num_classes, num_blocks_per_stage=None, network_props=None, deep_supervision=False,
+ upscale_logits=False):
+ super(PreActResidualUNetDecoder, self).__init__()
+ self.num_classes = num_classes
+ self.deep_supervision = deep_supervision
+ """
+ We assume the bottleneck is part of the encoder, so we can start with upsample -> concat here
+ """
+ previous_stages = previous.stages
+ previous_stage_output_features = previous.stage_output_features
+ previous_stage_pool_kernel_size = previous.stage_pool_kernel_size
+ previous_stage_conv_op_kernel_size = previous.stage_conv_op_kernel_size
+
+ if network_props is None:
+ self.props = previous.props
+ else:
+ self.props = network_props
+
+ if self.props['conv_op'] == nn.Conv2d:
+ transpconv = nn.ConvTranspose2d
+ upsample_mode = "bilinear"
+ elif self.props['conv_op'] == nn.Conv3d:
+ transpconv = nn.ConvTranspose3d
+ upsample_mode = "trilinear"
+ else:
+ raise ValueError("unknown convolution dimensionality, conv op: %s" % str(self.props['conv_op']))
+
+ if num_blocks_per_stage is None:
+ num_blocks_per_stage = previous.num_blocks_per_stage[:-1][::-1]
+
+ assert len(num_blocks_per_stage) == len(previous.num_blocks_per_stage) - 1
+
+ self.stage_pool_kernel_size = previous_stage_pool_kernel_size
+ self.stage_output_features = previous_stage_output_features
+ self.stage_conv_op_kernel_size = previous_stage_conv_op_kernel_size
+
+        # we have one stage less because the first decoder stage is the one that comes after the bottleneck
+        num_stages = len(previous_stages) - 1
+
+ self.tus = []
+ self.stages = []
+ self.deep_supervision_outputs = []
+
+        # only used for upscale_logits
+ cum_upsample = np.cumprod(np.vstack(self.stage_pool_kernel_size), axis=0).astype(int)
+
+ for i, s in enumerate(np.arange(num_stages)[::-1]):
+ features_below = previous_stage_output_features[s + 1]
+ features_skip = previous_stage_output_features[s]
+
+ self.tus.append(transpconv(features_below, features_skip, previous_stage_pool_kernel_size[s + 1],
+ previous_stage_pool_kernel_size[s + 1], bias=False))
+            # after the transposed conv we concatenate the skip, so the stage sees 2 * features_skip channels
+ self.stages.append(PreActResidualLayer(2 * features_skip, features_skip, previous_stage_conv_op_kernel_size[s],
+ self.props, num_blocks_per_stage[i], None))
+
+ if deep_supervision and s != 0:
+ norm = self.props['norm_op'](features_skip, **self.props['norm_op_kwargs'])
+ nonlin = self.props['nonlin'](**self.props['nonlin_kwargs'])
+ seg_layer = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, bias=True)
+ if upscale_logits:
+ upsample = Upsample(scale_factor=cum_upsample[s], mode=upsample_mode)
+ self.deep_supervision_outputs.append(nn.Sequential(norm, nonlin, seg_layer, upsample))
+ else:
+ self.deep_supervision_outputs.append(nn.Sequential(norm, nonlin, seg_layer))
+
+ self.segmentation_conv_norm = self.props['norm_op'](features_skip, **self.props['norm_op_kwargs'])
+ self.segmentation_conv_nonlin = self.props['nonlin'](**self.props['nonlin_kwargs'])
+ self.segmentation_output = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, bias=True)
+ self.segmentation_output = nn.Sequential(self.segmentation_conv_norm, self.segmentation_conv_nonlin,
+ self.segmentation_output)
+
+ self.tus = nn.ModuleList(self.tus)
+ self.stages = nn.ModuleList(self.stages)
+ self.deep_supervision_outputs = nn.ModuleList(self.deep_supervision_outputs)
+
+    def forward(self, skips):
+        # skips come from the encoder and are sorted so that the bottleneck is last in the list.
+        # The TUs and stages here are sorted the other way around, so we simply reverse the order of skips.
+        skips = skips[::-1]
+ seg_outputs = []
+
+ x = skips[0] # this is the bottleneck
+
+ for i in range(len(self.tus)):
+ x = self.tus[i](x)
+ x = torch.cat((x, skips[i + 1]), dim=1)
+ x = self.stages[i](x)
+ if self.deep_supervision and (i != len(self.tus) - 1):
+ seg_outputs.append(self.deep_supervision_outputs[i](x))
+
+ segmentation = self.segmentation_output(x)
+
+ if self.deep_supervision:
+ seg_outputs.append(segmentation)
+            # seg_outputs are ordered so that the prediction at the highest resolution is first and the one from
+            # the bottleneck of the UNet is last
+            return seg_outputs[::-1]
+ else:
+ return segmentation
+
+ @staticmethod
+ def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_classes, pool_op_kernel_sizes, num_blocks_per_stage_decoder,
+ feat_map_mul_on_downscale, batch_size):
+ """
+ This only applies for num_conv_per_stage and convolutional_upsampling=True
+ not real vram consumption. just a constant term to which the vram consumption will be approx proportional
+ (+ offset for parameter storage)
+ :param patch_size:
+ :param num_pool_per_axis:
+ :param base_num_features:
+ :param max_num_features:
+ :return:
+ """
+ npool = len(pool_op_kernel_sizes) - 1
+
+ current_shape = np.array(patch_size)
+ tmp = (num_blocks_per_stage_decoder[-1] * 2 + 1) * np.prod(current_shape) * base_num_features + num_classes * np.prod(current_shape)
+
+ num_feat = base_num_features
+
+ for p in range(1, npool):
+ current_shape = current_shape / np.array(pool_op_kernel_sizes[p])
+ num_feat = min(num_feat * feat_map_mul_on_downscale, max_num_features)
+ num_convs = num_blocks_per_stage_decoder[-(p + 1)] * 2 + 1 + 1 # +1 for transpconv and +1 for conv in skip
+ print(p, num_feat, num_convs, current_shape)
+ tmp += num_convs * np.prod(current_shape) * num_feat
+
+ return tmp * batch_size
+
+
+class PreActResidualUNet(SegmentationNetwork):
+ use_this_for_batch_size_computation_2D = 858931200.0 # 1167982592.0
+ use_this_for_batch_size_computation_3D = 727842816.0 # 1152286720.0
+
+ def __init__(self, input_channels, base_num_features, num_blocks_per_stage_encoder, feat_map_mul_on_downscale,
+ pool_op_kernel_sizes, conv_kernel_sizes, props, num_classes, num_blocks_per_stage_decoder,
+ deep_supervision=False, upscale_logits=False, max_features=512, initializer=None):
+ super(PreActResidualUNet, self).__init__()
+ self.conv_op = props['conv_op']
+ self.num_classes = num_classes
+
+ self.encoder = PreActResidualUNetEncoder(input_channels, base_num_features, num_blocks_per_stage_encoder,
+ feat_map_mul_on_downscale, pool_op_kernel_sizes, conv_kernel_sizes,
+ props, default_return_skips=True, max_num_features=max_features)
+ self.decoder = PreActResidualUNetDecoder(self.encoder, num_classes, num_blocks_per_stage_decoder, props,
+ deep_supervision, upscale_logits)
+ if initializer is not None:
+ self.apply(initializer)
+
+ def forward(self, x):
+ skips = self.encoder(x)
+ return self.decoder(skips)
+
+ @staticmethod
+ def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_modalities, num_classes, pool_op_kernel_sizes, num_conv_per_stage_encoder,
+ num_conv_per_stage_decoder, feat_map_mul_on_downscale, batch_size):
+ enc = PreActResidualUNetEncoder.compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_modalities, pool_op_kernel_sizes,
+ num_conv_per_stage_encoder,
+ feat_map_mul_on_downscale, batch_size)
+ dec = PreActResidualUNetDecoder.compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_classes, pool_op_kernel_sizes,
+ num_conv_per_stage_decoder,
+ feat_map_mul_on_downscale, batch_size)
+
+ return enc + dec
+
+ @staticmethod
+ def compute_reference_for_vram_consumption_3d():
+ patch_size = (128, 128, 128)
+ pool_op_kernel_sizes = ((1, 1, 1),
+ (2, 2, 2),
+ (2, 2, 2),
+ (2, 2, 2),
+ (2, 2, 2),
+ (2, 2, 2))
+ blocks_per_stage_encoder = (1, 1, 1, 1, 1, 1)
+ blocks_per_stage_decoder = (1, 1, 1, 1, 1)
+
+ return PreActResidualUNet.compute_approx_vram_consumption(patch_size, 20, 512, 4, 3, pool_op_kernel_sizes,
+ blocks_per_stage_encoder, blocks_per_stage_decoder, 2, 2)
+
+ @staticmethod
+ def compute_reference_for_vram_consumption_2d():
+ patch_size = (256, 256)
+ pool_op_kernel_sizes = (
+ (1, 1), # (256, 256)
+ (2, 2), # (128, 128)
+ (2, 2), # (64, 64)
+ (2, 2), # (32, 32)
+ (2, 2), # (16, 16)
+ (2, 2), # (8, 8)
+ (2, 2) # (4, 4)
+ )
+ blocks_per_stage_encoder = (1, 1, 1, 1, 1, 1, 1)
+ blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1)
+
+ return PreActResidualUNet.compute_approx_vram_consumption(patch_size, 20, 512, 4, 3, pool_op_kernel_sizes,
+ blocks_per_stage_encoder, blocks_per_stage_decoder, 2, 50)
+
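+# PreActResidualUNet is not exercised below (find_*_configuration uses FabiansPreActUNet); a minimal
+# end-to-end sketch would look like this (illustrative only):
+#
+#   props = get_default_network_config(2, dropout_p=None)
+#   net = PreActResidualUNet(4, 20, (1, 1, 1), 2, ((1, 1), (2, 2), (2, 2)),
+#                            ((3, 3), (3, 3), (3, 3)), props, num_classes=3,
+#                            num_blocks_per_stage_decoder=(1, 1))
+#   logits = net(torch.rand((2, 4, 64, 64)))  # -> torch.Size([2, 3, 64, 64])
+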
+
+class FabiansPreActUNet(SegmentationNetwork):
+ use_this_for_2D_configuration = 1792460800
+ use_this_for_3D_configuration = 1318592512
+ default_blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6)
+ default_blocks_per_stage_decoder = (2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2)
+ default_min_batch_size = 2 # this is what works with the numbers above
+
+ def __init__(self, input_channels, base_num_features, num_blocks_per_stage_encoder, feat_map_mul_on_downscale,
+ pool_op_kernel_sizes, conv_kernel_sizes, props, num_classes, num_blocks_per_stage_decoder,
+ deep_supervision=False, upscale_logits=False, max_features=512, initializer=None):
+ super().__init__()
+ self.conv_op = props['conv_op']
+ self.num_classes = num_classes
+
+ self.encoder = PreActResidualUNetEncoder(input_channels, base_num_features, num_blocks_per_stage_encoder,
+ feat_map_mul_on_downscale, pool_op_kernel_sizes, conv_kernel_sizes,
+ props, default_return_skips=True, max_num_features=max_features)
+        props['dropout_op_kwargs']['p'] = 0  # disable dropout in the decoder
+ self.decoder = PlainConvUNetDecoder(self.encoder, num_classes, num_blocks_per_stage_decoder, props,
+ deep_supervision, upscale_logits)
+
+ expected_num_skips = len(conv_kernel_sizes) - 1
+        num_features_skips = [min(max_features, base_num_features * 2 ** i) for i in range(expected_num_skips)]  # assumes feat_map_mul_on_downscale == 2
+ norm_nonlins = []
+ for i in range(expected_num_skips):
+ norm_nonlins.append(nn.Sequential(props['norm_op'](num_features_skips[i], **props['norm_op_kwargs']), props['nonlin'](**props['nonlin_kwargs'])))
+ self.norm_nonlins = nn.ModuleList(norm_nonlins)
+
+ if initializer is not None:
+ self.apply(initializer)
+
+ def forward(self, x, gt=None, loss=None):
+ skips = self.encoder(x)
+        for i, op in enumerate(self.norm_nonlins):
+            skips[i] = op(skips[i])
+ return self.decoder(skips, gt, loss)
+
+ @staticmethod
+ def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_modalities, num_classes, pool_op_kernel_sizes, num_blocks_per_stage_encoder,
+ num_blocks_per_stage_decoder, feat_map_mul_on_downscale, batch_size):
+ enc = PreActResidualUNetEncoder.compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_modalities, pool_op_kernel_sizes,
+ num_blocks_per_stage_encoder,
+ feat_map_mul_on_downscale, batch_size)
+ dec = PlainConvUNetDecoder.compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_classes, pool_op_kernel_sizes,
+ num_blocks_per_stage_decoder,
+ feat_map_mul_on_downscale, batch_size)
+
+ return enc + dec
+
+
+def find_3d_configuration():
+ cudnn.benchmark = True
+ cudnn.deterministic = False
+
+ conv_op_kernel_sizes = ((3, 3, 3),
+ (3, 3, 3),
+ (3, 3, 3),
+ (3, 3, 3),
+ (3, 3, 3),
+ (3, 3, 3))
+ pool_op_kernel_sizes = ((1, 1, 1),
+ (2, 2, 2),
+ (2, 2, 2),
+ (2, 2, 2),
+ (2, 2, 2),
+ (2, 2, 2))
+
+ patch_size = (128, 128, 128)
+ base_num_features = 32
+ input_modalities = 4
+ blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6)
+ blocks_per_stage_decoder = (2, 2, 2, 2, 2)
+ feat_map_mult_on_downscale = 2
+ num_classes = 5
+ max_features = 320
+ batch_size = 2
+
+ unet = FabiansPreActUNet(input_modalities, base_num_features, blocks_per_stage_encoder, feat_map_mult_on_downscale,
+ pool_op_kernel_sizes, conv_op_kernel_sizes, get_default_network_config(3, dropout_p=None), num_classes,
+ blocks_per_stage_decoder, True, False, max_features=max_features).cuda()
+
+ scaler = GradScaler()
+ optimizer = SGD(unet.parameters(), lr=0.1, momentum=0.95)
+
+ print(unet.compute_approx_vram_consumption(patch_size, base_num_features, max_features, input_modalities,
+ num_classes, pool_op_kernel_sizes, blocks_per_stage_encoder,
+ blocks_per_stage_decoder, feat_map_mult_on_downscale, batch_size))
+
+ loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'do_bg': False}, {})
+
+ dummy_input = torch.rand((batch_size, input_modalities, *patch_size)).cuda()
+ dummy_gt = (torch.rand((batch_size, 1, *patch_size)) * num_classes).round().clamp_(0, num_classes-1).cuda().long()
+
+ for i in range(10):
+ optimizer.zero_grad()
+
+ with autocast():
+ skips = unet.encoder(dummy_input)
+ print([i.shape for i in skips])
+ output = unet.decoder(skips)[0]
+
+ l = loss(output, dummy_gt)
+ print(l.item())
+ scaler.scale(l).backward()
+ scaler.step(optimizer)
+ scaler.update()
+
+ with autocast():
+ import hiddenlayer as hl
+ g = hl.build_graph(unet, dummy_input, transforms=None)
+ g.save("/home/fabian/test_arch.pdf")
+
+
+def find_2d_configuration():
+ cudnn.benchmark = True
+ cudnn.deterministic = False
+
+ conv_op_kernel_sizes = ((3, 3),
+ (3, 3),
+ (3, 3),
+ (3, 3),
+ (3, 3),
+ (3, 3),
+ (3, 3))
+ pool_op_kernel_sizes = ((1, 1),
+ (2, 2),
+ (2, 2),
+ (2, 2),
+ (2, 2),
+ (2, 2),
+ (2, 2))
+
+ patch_size = (256, 256)
+ base_num_features = 32
+ input_modalities = 4
+ blocks_per_stage_encoder = (1, 3, 4, 6, 6, 6, 6)
+ blocks_per_stage_decoder = (2, 2, 2, 2, 2, 2)
+ feat_map_mult_on_downscale = 2
+ num_classes = 5
+ max_features = 512
+ batch_size = 50
+
+ unet = FabiansPreActUNet(input_modalities, base_num_features, blocks_per_stage_encoder, feat_map_mult_on_downscale,
+ pool_op_kernel_sizes, conv_op_kernel_sizes, get_default_network_config(2, dropout_p=None), num_classes,
+ blocks_per_stage_decoder, True, False, max_features=max_features).cuda()
+
+ scaler = GradScaler()
+ optimizer = SGD(unet.parameters(), lr=0.1, momentum=0.95)
+
+ print(unet.compute_approx_vram_consumption(patch_size, base_num_features, max_features, input_modalities,
+ num_classes, pool_op_kernel_sizes, blocks_per_stage_encoder,
+ blocks_per_stage_decoder, feat_map_mult_on_downscale, batch_size))
+
+ loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'do_bg': False}, {})
+
+ dummy_input = torch.rand((batch_size, input_modalities, *patch_size)).cuda()
+ dummy_gt = (torch.rand((batch_size, 1, *patch_size)) * num_classes).round().clamp_(0, num_classes-1).cuda().long()
+
+ for i in range(10):
+ optimizer.zero_grad()
+
+ with autocast():
+ skips = unet.encoder(dummy_input)
+ print([i.shape for i in skips])
+ output = unet.decoder(skips)[0]
+
+ l = loss(output, dummy_gt)
+ print(l.item())
+ scaler.scale(l).backward()
+ scaler.step(optimizer)
+ scaler.update()
+
+ with autocast():
+ import hiddenlayer as hl
+ g = hl.build_graph(unet, dummy_input, transforms=None)
+ g.save("/home/fabian/test_arch.pdf")
diff --git a/nnunet/network_architecture/generic_modular_residual_UNet.py b/nnunet/network_architecture/generic_modular_residual_UNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a67e79a8af275902c368e9eea7f3bd2ad2aa111
--- /dev/null
+++ b/nnunet/network_architecture/generic_modular_residual_UNet.py
@@ -0,0 +1,509 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+import torch
+from nnunet.network_architecture.custom_modules.conv_blocks import BasicResidualBlock, ResidualLayer
+from nnunet.network_architecture.generic_UNet import Upsample
+from nnunet.network_architecture.generic_modular_UNet import PlainConvUNetDecoder, get_default_network_config
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.training.loss_functions.dice_loss import DC_and_CE_loss
+from torch import nn
+from torch.optim import SGD
+from torch.backends import cudnn
+
+
+class ResidualUNetEncoder(nn.Module):
+ def __init__(self, input_channels, base_num_features, num_blocks_per_stage, feat_map_mul_on_downscale,
+ pool_op_kernel_sizes, conv_kernel_sizes, props, default_return_skips=True,
+ max_num_features=480, block=BasicResidualBlock):
+ """
+        Further UNet building blocks can be added by utilizing the properties this class exposes (TODO)
+
+        this one includes the bottleneck layer!
+
+        :param input_channels: number of channels/modalities of the network input
+        :param base_num_features: number of feature maps in the first (highest resolution) stage
+        :param num_blocks_per_stage: int or list/tuple with one entry per stage
+        :param feat_map_mul_on_downscale: factor by which the number of feature maps grows per downsampling
+        :param pool_op_kernel_sizes: pooling/stride kernel size for each stage
+        :param conv_kernel_sizes: convolution kernel size for each stage
+        :param props: dict of network properties (conv_op, norm_op, nonlin, dropout_op and their kwargs)
+ """
+ super(ResidualUNetEncoder, self).__init__()
+
+ self.default_return_skips = default_return_skips
+ self.props = props
+
+ self.stages = []
+ self.stage_output_features = []
+ self.stage_pool_kernel_size = []
+ self.stage_conv_op_kernel_size = []
+
+ assert len(pool_op_kernel_sizes) == len(conv_kernel_sizes)
+
+ num_stages = len(conv_kernel_sizes)
+
+ if not isinstance(num_blocks_per_stage, (list, tuple)):
+ num_blocks_per_stage = [num_blocks_per_stage] * num_stages
+ else:
+ assert len(num_blocks_per_stage) == num_stages
+
+ self.num_blocks_per_stage = num_blocks_per_stage # decoder may need this
+
+ self.initial_conv = props['conv_op'](input_channels, base_num_features, 3, padding=1, **props['conv_op_kwargs'])
+ self.initial_norm = props['norm_op'](base_num_features, **props['norm_op_kwargs'])
+ self.initial_nonlin = props['nonlin'](**props['nonlin_kwargs'])
+
+ current_input_features = base_num_features
+ for stage in range(num_stages):
+ current_output_features = min(base_num_features * feat_map_mul_on_downscale ** stage, max_num_features)
+ current_kernel_size = conv_kernel_sizes[stage]
+ current_pool_kernel_size = pool_op_kernel_sizes[stage]
+
+ current_stage = ResidualLayer(current_input_features, current_output_features, current_kernel_size, props,
+ self.num_blocks_per_stage[stage], current_pool_kernel_size, block)
+
+ self.stages.append(current_stage)
+ self.stage_output_features.append(current_output_features)
+ self.stage_conv_op_kernel_size.append(current_kernel_size)
+ self.stage_pool_kernel_size.append(current_pool_kernel_size)
+
+ # update current_input_features
+ current_input_features = current_output_features
+
+ self.stages = nn.ModuleList(self.stages)
+
+    def forward(self, x, return_skips=None):
+        """
+
+        :param x: input tensor
+        :param return_skips: if None then self.default_return_skips is used
+        :return: list of skips (bottleneck last) if return_skips, else only the final stage output
+        """
+        if return_skips is None:
+            return_skips = self.default_return_skips
+
+        skips = []
+
+        x = self.initial_nonlin(self.initial_norm(self.initial_conv(x)))
+        for s in self.stages:
+            x = s(x)
+            if return_skips:
+                skips.append(x)
+
+        if return_skips:
+            return skips
+        else:
+            return x
+
+ @staticmethod
+ def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_modalities, pool_op_kernel_sizes, num_conv_per_stage_encoder,
+ feat_map_mul_on_downscale, batch_size):
+ npool = len(pool_op_kernel_sizes) - 1
+
+ current_shape = np.array(patch_size)
+
+ tmp = (num_conv_per_stage_encoder[0] * 2 + 1) * np.prod(current_shape) * base_num_features \
+ + num_modalities * np.prod(current_shape)
+
+ num_feat = base_num_features
+
+ for p in range(1, npool + 1):
+ current_shape = current_shape / np.array(pool_op_kernel_sizes[p])
+ num_feat = min(num_feat * feat_map_mul_on_downscale, max_num_features)
+ num_convs = num_conv_per_stage_encoder[p] * 2 + 1 # + 1 for conv in skip in first block
+ print(p, num_feat, num_convs, current_shape)
+ tmp += num_convs * np.prod(current_shape) * num_feat
+ return tmp * batch_size
+
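+# Toy walk-through of the proxy above (illustrative arithmetic, not real VRAM): patch (64, 64),
+# base_num_features 32, pool_op_kernel_sizes ((1, 1), (2, 2)), 2 blocks per stage, 1 modality,
+# batch_size 1:
+#   stage 0: (2 * 2 + 1) * 64 * 64 * 32 + 1 * 64 * 64 = 655360 + 4096
+#   stage 1: (2 * 2 + 1) * 32 * 32 * 64               = 327680
+#   total:                                              987136
+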
+
+class ResidualUNetDecoder(nn.Module):
+ def __init__(self, previous, num_classes, num_blocks_per_stage=None, network_props=None, deep_supervision=False,
+ upscale_logits=False, block=BasicResidualBlock):
+ super(ResidualUNetDecoder, self).__init__()
+ self.num_classes = num_classes
+ self.deep_supervision = deep_supervision
+ """
+ We assume the bottleneck is part of the encoder, so we can start with upsample -> concat here
+ """
+ previous_stages = previous.stages
+ previous_stage_output_features = previous.stage_output_features
+ previous_stage_pool_kernel_size = previous.stage_pool_kernel_size
+ previous_stage_conv_op_kernel_size = previous.stage_conv_op_kernel_size
+
+ if network_props is None:
+ self.props = previous.props
+ else:
+ self.props = network_props
+
+ if self.props['conv_op'] == nn.Conv2d:
+ transpconv = nn.ConvTranspose2d
+ upsample_mode = "bilinear"
+ elif self.props['conv_op'] == nn.Conv3d:
+ transpconv = nn.ConvTranspose3d
+ upsample_mode = "trilinear"
+ else:
+ raise ValueError("unknown convolution dimensionality, conv op: %s" % str(self.props['conv_op']))
+
+ if num_blocks_per_stage is None:
+ num_blocks_per_stage = previous.num_blocks_per_stage[:-1][::-1]
+
+ assert len(num_blocks_per_stage) == len(previous.num_blocks_per_stage) - 1
+
+ self.stage_pool_kernel_size = previous_stage_pool_kernel_size
+ self.stage_output_features = previous_stage_output_features
+ self.stage_conv_op_kernel_size = previous_stage_conv_op_kernel_size
+
+        # we have one stage less because the first decoder stage is the one that comes after the bottleneck
+        num_stages = len(previous_stages) - 1
+
+ self.tus = []
+ self.stages = []
+ self.deep_supervision_outputs = []
+
+        # only used for upscale_logits
+ cum_upsample = np.cumprod(np.vstack(self.stage_pool_kernel_size), axis=0).astype(int)
+
+ for i, s in enumerate(np.arange(num_stages)[::-1]):
+ features_below = previous_stage_output_features[s + 1]
+ features_skip = previous_stage_output_features[s]
+
+ self.tus.append(transpconv(features_below, features_skip, previous_stage_pool_kernel_size[s + 1],
+ previous_stage_pool_kernel_size[s + 1], bias=False))
+            # after the transposed conv we concatenate the skip, so the stage sees 2 * features_skip channels
+ self.stages.append(ResidualLayer(2 * features_skip, features_skip, previous_stage_conv_op_kernel_size[s],
+ self.props, num_blocks_per_stage[i], None, block))
+
+ if deep_supervision and s != 0:
+ seg_layer = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, False)
+ if upscale_logits:
+ upsample = Upsample(scale_factor=cum_upsample[s], mode=upsample_mode)
+ self.deep_supervision_outputs.append(nn.Sequential(seg_layer, upsample))
+ else:
+ self.deep_supervision_outputs.append(seg_layer)
+
+ self.segmentation_output = self.props['conv_op'](features_skip, num_classes, 1, 1, 0, 1, 1, False)
+
+ self.tus = nn.ModuleList(self.tus)
+ self.stages = nn.ModuleList(self.stages)
+ self.deep_supervision_outputs = nn.ModuleList(self.deep_supervision_outputs)
+
+    def forward(self, skips):
+        # skips come from the encoder and are sorted so that the bottleneck is last in the list.
+        # The TUs and stages here are sorted the other way around, so we simply reverse the order of skips.
+        skips = skips[::-1]
+ seg_outputs = []
+
+ x = skips[0] # this is the bottleneck
+
+ for i in range(len(self.tus)):
+ x = self.tus[i](x)
+ x = torch.cat((x, skips[i + 1]), dim=1)
+ x = self.stages[i](x)
+ if self.deep_supervision and (i != len(self.tus) - 1):
+ seg_outputs.append(self.deep_supervision_outputs[i](x))
+
+ segmentation = self.segmentation_output(x)
+
+        if self.deep_supervision:
+            seg_outputs.append(segmentation)
+            # seg_outputs are ordered so that the prediction at the highest resolution is first and the one from
+            # the bottleneck of the UNet is last
+            return seg_outputs[::-1]
+        else:
+            return segmentation
+
+ @staticmethod
+ def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_classes, pool_op_kernel_sizes, num_blocks_per_stage_decoder,
+ feat_map_mul_on_downscale, batch_size):
+ """
+ This only applies for num_conv_per_stage and convolutional_upsampling=True
+ not real vram consumption. just a constant term to which the vram consumption will be approx proportional
+ (+ offset for parameter storage)
+ :param patch_size:
+ :param num_pool_per_axis:
+ :param base_num_features:
+ :param max_num_features:
+ :return:
+ """
+ npool = len(pool_op_kernel_sizes) - 1
+
+ current_shape = np.array(patch_size)
+ tmp = (num_blocks_per_stage_decoder[-1] * 2 + 1) * np.prod(
+ current_shape) * base_num_features + num_classes * np.prod(current_shape)
+
+ num_feat = base_num_features
+
+ for p in range(1, npool):
+ current_shape = current_shape / np.array(pool_op_kernel_sizes[p])
+ num_feat = min(num_feat * feat_map_mul_on_downscale, max_num_features)
+ num_convs = num_blocks_per_stage_decoder[-(p + 1)] * 2 + 1 + 1 # +1 for transpconv and +1 for conv in skip
+ print(p, num_feat, num_convs, current_shape)
+ tmp += num_convs * np.prod(current_shape) * num_feat
+
+ return tmp * batch_size
+
+
+class ResidualUNet(SegmentationNetwork):
+ use_this_for_batch_size_computation_2D = 858931200.0 # 1167982592.0
+ use_this_for_batch_size_computation_3D = 727842816.0 # 1152286720.0
+ default_base_num_features = 24
+ default_conv_per_stage = (2, 2, 2, 2, 2, 2, 2, 2)
+
+ def __init__(self, input_channels, base_num_features, num_blocks_per_stage_encoder, feat_map_mul_on_downscale,
+ pool_op_kernel_sizes, conv_kernel_sizes, props, num_classes, num_blocks_per_stage_decoder,
+ deep_supervision=False, upscale_logits=False, max_features=512, initializer=None,
+ block=BasicResidualBlock):
+ super(ResidualUNet, self).__init__()
+ self.conv_op = props['conv_op']
+ self.num_classes = num_classes
+
+ self.encoder = ResidualUNetEncoder(input_channels, base_num_features, num_blocks_per_stage_encoder,
+ feat_map_mul_on_downscale, pool_op_kernel_sizes, conv_kernel_sizes,
+ props, default_return_skips=True, max_num_features=max_features, block=block)
+ self.decoder = ResidualUNetDecoder(self.encoder, num_classes, num_blocks_per_stage_decoder, props,
+ deep_supervision, upscale_logits, block=block)
+ if initializer is not None:
+ self.apply(initializer)
+
+ def forward(self, x):
+ skips = self.encoder(x)
+ return self.decoder(skips)
+
+ @staticmethod
+ def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_modalities, num_classes, pool_op_kernel_sizes, num_conv_per_stage_encoder,
+ num_conv_per_stage_decoder, feat_map_mul_on_downscale, batch_size):
+ enc = ResidualUNetEncoder.compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_modalities, pool_op_kernel_sizes,
+ num_conv_per_stage_encoder,
+ feat_map_mul_on_downscale, batch_size)
+ dec = ResidualUNetDecoder.compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_classes, pool_op_kernel_sizes,
+ num_conv_per_stage_decoder,
+ feat_map_mul_on_downscale, batch_size)
+
+ return enc + dec
+
+
+class FabiansUNet(SegmentationNetwork):
+ """
+ Residual Encoder, Plain conv decoder
+ """
+ use_this_for_2D_configuration = 1244233721.0 # 1167982592.0
+ use_this_for_3D_configuration = 1230348801.0
+ default_blocks_per_stage_encoder = (1, 2, 3, 4, 4, 4, 4, 4, 4, 4, 4)
+ default_blocks_per_stage_decoder = (1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
+ default_min_batch_size = 2 # this is what works with the numbers above
+
+ def __init__(self, input_channels, base_num_features, num_blocks_per_stage_encoder, feat_map_mul_on_downscale,
+ pool_op_kernel_sizes, conv_kernel_sizes, props, num_classes, num_blocks_per_stage_decoder,
+ deep_supervision=False, upscale_logits=False, max_features=512, initializer=None,
+ block=BasicResidualBlock,
+ props_decoder=None):
+ super().__init__()
+ self.conv_op = props['conv_op']
+ self.num_classes = num_classes
+
+ self.encoder = ResidualUNetEncoder(input_channels, base_num_features, num_blocks_per_stage_encoder,
+ feat_map_mul_on_downscale, pool_op_kernel_sizes, conv_kernel_sizes,
+ props, default_return_skips=True, max_num_features=max_features, block=block)
+ props['dropout_op_kwargs']['p'] = 0
+ if props_decoder is None:
+ props_decoder = props
+ self.decoder = PlainConvUNetDecoder(self.encoder, num_classes, num_blocks_per_stage_decoder, props_decoder,
+ deep_supervision, upscale_logits)
+ if initializer is not None:
+ self.apply(initializer)
+
+ def forward(self, x):
+ skips = self.encoder(x)
+ return self.decoder(skips)
+
+ @staticmethod
+ def compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_modalities, num_classes, pool_op_kernel_sizes, num_conv_per_stage_encoder,
+ num_conv_per_stage_decoder, feat_map_mul_on_downscale, batch_size):
+ enc = ResidualUNetEncoder.compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_modalities, pool_op_kernel_sizes,
+ num_conv_per_stage_encoder,
+ feat_map_mul_on_downscale, batch_size)
+ dec = PlainConvUNetDecoder.compute_approx_vram_consumption(patch_size, base_num_features, max_num_features,
+ num_classes, pool_op_kernel_sizes,
+ num_conv_per_stage_decoder,
+ feat_map_mul_on_downscale, batch_size)
+
+ return enc + dec
+
+
+def find_3d_configuration():
+    # let's compute a reference for 3D
+    # we select hyperparameters here so that we get approximately the same patch size as we would get with the
+    # regular unet. This is just my choice. You can do whatever you want
+    # These default hyperparameters will then be used by the experiment planner
+
+    # since this is more parameter intensive than the UNet, we will test a configuration that has a lot of parameters
+    # therefore we copy the UNet configuration for Task005_Prostate
+ cudnn.deterministic = False
+ cudnn.benchmark = True
+
+ patch_size = (20, 320, 256)
+ max_num_features = 320
+ num_modalities = 2
+ num_classes = 3
+ batch_size = 2
+
+ # now we fiddle with the network specific hyperparameters until everything just barely fits into a titanx
+ blocks_per_stage_encoder = FabiansUNet.default_blocks_per_stage_encoder
+ blocks_per_stage_decoder = FabiansUNet.default_blocks_per_stage_decoder
+ initial_num_features = 32
+
+    # we need to add a [1, 1, 1] for the res unet because in this implementation all stages of the encoder can have a stride
+ pool_op_kernel_sizes = [[1, 1, 1],
+ [1, 2, 2],
+ [1, 2, 2],
+ [2, 2, 2],
+ [2, 2, 2],
+ [1, 2, 2],
+ [1, 2, 2]]
+
+ conv_op_kernel_sizes = [[1, 3, 3],
+ [1, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3],
+ [3, 3, 3]]
+
+ unet = FabiansUNet(num_modalities, initial_num_features, blocks_per_stage_encoder[:len(conv_op_kernel_sizes)], 2,
+ pool_op_kernel_sizes, conv_op_kernel_sizes,
+ get_default_network_config(3, dropout_p=None), num_classes,
+ blocks_per_stage_decoder[:len(conv_op_kernel_sizes)-1], False, False,
+ max_features=max_num_features).cuda()
+
+ optimizer = SGD(unet.parameters(), lr=0.1, momentum=0.95)
+ loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'do_bg': False}, {})
+
+ dummy_input = torch.rand((batch_size, num_modalities, *patch_size)).cuda()
+ dummy_gt = (torch.rand((batch_size, 1, *patch_size)) * num_classes).round().clamp_(0, 2).cuda().long()
+
+    for iteration in range(20):
+        optimizer.zero_grad()
+        skips = unet.encoder(dummy_input)
+        print([i.shape for i in skips])
+        output = unet.decoder(skips)
+
+        l = loss(output, dummy_gt)
+        l.backward()
+
+        optimizer.step()
+        if iteration == 0:
+            # free cached memory once after the first iteration
+            torch.cuda.empty_cache()
+
+ # that should do. Now take the network hyperparameters and insert them in FabiansUNet.compute_approx_vram_consumption
+ # whatever number this spits out, save it to FabiansUNet.use_this_for_batch_size_computation_3D
+ print(FabiansUNet.compute_approx_vram_consumption(patch_size, initial_num_features, max_num_features, num_modalities,
+ num_classes, pool_op_kernel_sizes,
+ blocks_per_stage_encoder[:len(conv_op_kernel_sizes)],
+ blocks_per_stage_decoder[:len(conv_op_kernel_sizes)-1], 2, batch_size))
+    # the output is 1230348800.0 for me
+    # I increment that number by 1 to allow this configuration to be chosen
+
+
+def find_2d_configuration():
+    # let's compute a reference for 2D
+    # we select hyperparameters here so that we get approximately the same patch size as we would get with the
+    # regular unet. This is just my choice. You can do whatever you want
+    # These default hyperparameters will then be used by the experiment planner
+
+    # since this is more parameter intensive than the UNet, we will test a configuration that has a lot of parameters
+    # therefore we copy the UNet configuration for Task003_Liver
+ cudnn.deterministic = False
+ cudnn.benchmark = True
+
+ patch_size = (512, 512)
+ max_num_features = 512
+ num_modalities = 1
+ num_classes = 3
+ batch_size = 12
+
+ # now we fiddle with the network specific hyperparameters until everything just barely fits into a titanx
+ blocks_per_stage_encoder = FabiansUNet.default_blocks_per_stage_encoder
+ blocks_per_stage_decoder = FabiansUNet.default_blocks_per_stage_decoder
+ initial_num_features = 30
+
+    # we need to add a [1, 1] for the res unet because in this implementation all stages of the encoder can have a stride
+ pool_op_kernel_sizes = [[1, 1],
+ [2, 2],
+ [2, 2],
+ [2, 2],
+ [2, 2],
+ [2, 2],
+ [2, 2],
+ [2, 2]]
+
+ conv_op_kernel_sizes = [[3, 3],
+ [3, 3],
+ [3, 3],
+ [3, 3],
+ [3, 3],
+ [3, 3],
+ [3, 3],
+ [3, 3]]
+
+ unet = FabiansUNet(num_modalities, initial_num_features, blocks_per_stage_encoder[:len(conv_op_kernel_sizes)], 2,
+ pool_op_kernel_sizes, conv_op_kernel_sizes,
+ get_default_network_config(2, dropout_p=None), num_classes,
+ blocks_per_stage_decoder[:len(conv_op_kernel_sizes)-1], False, False,
+ max_features=max_num_features).cuda()
+
+ optimizer = SGD(unet.parameters(), lr=0.1, momentum=0.95)
+ loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'do_bg': False}, {})
+
+ dummy_input = torch.rand((batch_size, num_modalities, *patch_size)).cuda()
+ dummy_gt = (torch.rand((batch_size, 1, *patch_size)) * num_classes).round().clamp_(0, 2).cuda().long()
+
+    for iteration in range(20):
+        optimizer.zero_grad()
+        skips = unet.encoder(dummy_input)
+        print([i.shape for i in skips])
+        output = unet.decoder(skips)
+
+        l = loss(output, dummy_gt)
+        l.backward()
+
+        optimizer.step()
+        if iteration == 0:
+            # free cached memory once after the first iteration
+            torch.cuda.empty_cache()
+
+ # that should do. Now take the network hyperparameters and insert them in FabiansUNet.compute_approx_vram_consumption
+ # whatever number this spits out, save it to FabiansUNet.use_this_for_batch_size_computation_2D
+ print(FabiansUNet.compute_approx_vram_consumption(patch_size, initial_num_features, max_num_features, num_modalities,
+ num_classes, pool_op_kernel_sizes,
+ blocks_per_stage_encoder[:len(conv_op_kernel_sizes)],
+ blocks_per_stage_decoder[:len(conv_op_kernel_sizes)-1], 2, batch_size))
+    # the output is 1244233728.0 for me
+    # I increment that number by 1 to allow this configuration to be chosen
+    # This will not fit with 32 filters, but neither will the regular U-Net. We still use 32 filters in training.
+    # This does not matter because we are using mixed precision training now, so a rough memory approximation is OK
+
+
+if __name__ == "__main__":
+ pass
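+    # run find_2d_configuration() / find_3d_configuration() here to re-derive the reference numbers above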
+
diff --git a/nnunet/network_architecture/initialization.py b/nnunet/network_architecture/initialization.py
new file mode 100644
index 0000000000000000000000000000000000000000..901c4b132d23b794e3f0137d61ea963a4aeddabb
--- /dev/null
+++ b/nnunet/network_architecture/initialization.py
@@ -0,0 +1,38 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from torch import nn
+
+
+class InitWeights_He(object):
+ def __init__(self, neg_slope=1e-2):
+ self.neg_slope = neg_slope
+
+    def __call__(self, module):
+        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
+            module.weight = nn.init.kaiming_normal_(module.weight, a=self.neg_slope)
+            if module.bias is not None:
+                module.bias = nn.init.constant_(module.bias, 0)
+
+
+class InitWeights_XavierUniform(object):
+ def __init__(self, gain=1):
+ self.gain = gain
+
+    def __call__(self, module):
+        if isinstance(module, (nn.Conv3d, nn.Conv2d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
+            module.weight = nn.init.xavier_uniform_(module.weight, self.gain)
+            if module.bias is not None:
+                module.bias = nn.init.constant_(module.bias, 0)
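+
+
+# Typical usage (sketch): pass an instance as `initializer` to the UNet constructors elsewhere in
+# this package, or apply it to any module tree directly:
+#
+#   net = nn.Sequential(nn.Conv2d(1, 8, 3), nn.ReLU(), nn.Conv2d(8, 2, 3))
+#   net.apply(InitWeights_He(neg_slope=1e-2))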
diff --git a/nnunet/network_architecture/neural_network.py b/nnunet/network_architecture/neural_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cd69dbbfa3d49c66d19fa23b6f8aa6c26402783
--- /dev/null
+++ b/nnunet/network_architecture/neural_network.py
@@ -0,0 +1,845 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+from batchgenerators.augmentations.utils import pad_nd_image
+from nnunet.utilities.random_stuff import no_op
+from nnunet.utilities.to_torch import to_cuda, maybe_to_torch
+from torch import nn
+import torch
+from scipy.ndimage.filters import gaussian_filter
+from typing import Union, Tuple, List
+
+from torch.cuda.amp import autocast
+
+
+class NeuralNetwork(nn.Module):
+ def __init__(self):
+ super(NeuralNetwork, self).__init__()
+
+ def get_device(self):
+ if next(self.parameters()).device.type == "cpu":
+ return "cpu"
+ else:
+ return next(self.parameters()).device.index
+
+ def set_device(self, device):
+ if device == "cpu":
+ self.cpu()
+ else:
+ self.cuda(device)
+
+ def forward(self, x):
+ raise NotImplementedError
+
+
+class SegmentationNetwork(NeuralNetwork):
+ def __init__(self):
+ super(NeuralNetwork, self).__init__()
+
+ # if we have 5 pooling then our patch size must be divisible by 2**5
+ self.input_shape_must_be_divisible_by = None # for example in a 2d network that does 5 pool in x and 6 pool
+ # in y this would be (32, 64)
+
+        # we need to know this because we need to know whether we are a 2d or a 3d network
+ self.conv_op = None # nn.Conv2d or nn.Conv3d
+
+ # this tells us how many channels we have in the output. Important for preallocation in inference
+ self.num_classes = None # number of channels in the output
+
+ # depending on the loss, we do not hard code a nonlinearity into the architecture. To aggregate predictions
+ # during inference, we need to apply the nonlinearity, however. So it is important to let the network know
+ # what to apply during inference. For the most part this will be softmax
+ self.inference_apply_nonlin = lambda x: x # softmax_helper
+
+ # This is for saving a gaussian importance map for inference. It weights voxels that are closer to the
+ # center higher than those at the borders. Predictions at the borders are often less accurate and are thus
+ # downweighted. Creating these Gaussians can be expensive, so it makes sense to save and reuse them.
+ self._gaussian_3d = self._patch_size_for_gaussian_3d = None
+ self._gaussian_2d = self._patch_size_for_gaussian_2d = None
+
+ def predict_3D(self, x: np.ndarray, do_mirroring: bool, mirror_axes: Tuple[int, ...] = (0, 1, 2),
+ use_sliding_window: bool = False,
+ step_size: float = 0.5, patch_size: Tuple[int, ...] = None, regions_class_order: Tuple[int, ...] = None,
+ use_gaussian: bool = False, pad_border_mode: str = "constant",
+ pad_kwargs: dict = None, all_in_gpu: bool = False,
+ verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Use this function to predict a 3D image. It does not matter whether the network is a 2D or 3D U-Net, it will
+ detect that automatically and run the appropriate code.
+
+ When running predictions, you need to specify whether you want to run fully convolutional or sliding window
+ based inference. We very strongly recommend you use sliding window with the default settings.
+
+ It is the responsibility of the user to make sure the network is in the proper mode (eval for inference!). If
+ the network is not in eval mode it will print a warning.
+
+ :param x: Your input data. Must be an np.ndarray of shape (c, x, y, z).
+ :param do_mirroring: If True, use test time data augmentation in the form of mirroring
+ :param mirror_axes: Determines which axes to use for mirroring. Per default, mirroring is done along all three
+ axes
+ :param use_sliding_window: if True, run sliding window prediction. Heavily recommended! This is also the default
+ :param step_size: When running sliding window prediction, the step size determines the distance between adjacent
+ predictions. The smaller the step size, the denser the predictions (and the longer it takes!). Step size is given
+ as a fraction of the patch_size. 0.5 is the default and means that we advance by patch_size * 0.5 between
+ predictions. step_size cannot be larger than 1!
+ :param patch_size: The patch size that was used for training the network. Do not use different patch sizes here,
+ this will either crash or give potentially less accurate segmentations
+ :param regions_class_order: Fabian only
+ :param use_gaussian: (Only applies to sliding window prediction) If True, uses a Gaussian importance weighting
+ to weigh predictions closer to the center of the current patch higher than those at the borders. The reason
+ behind this is that the segmentation accuracy decreases towards the borders. Default (and recommended): True
+ :param pad_border_mode: leave this alone
+ :param pad_kwargs: leave this alone
+ :param all_in_gpu: experimental. You probably want to leave this as it is
+ :param verbose: Do you want a wall of text? If yes then set this to True
+ :param mixed_precision: if True, will run inference in mixed precision with autocast()
+ :return:
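+
+ Example (illustrative; assumes a trained 3d network net in eval mode that was trained
+ with patch size (128, 128, 128)):
+
+ seg, softmax = net.predict_3D(data, do_mirroring=True, use_sliding_window=True,
+ step_size=0.5, patch_size=(128, 128, 128), use_gaussian=True)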
+ """
+ torch.cuda.empty_cache()
+
+ assert step_size <= 1, 'step_size must not be larger than 1. Otherwise there will be a gap between consecutive ' \
+ 'predictions'
+
+ if verbose: print("debug: mirroring", do_mirroring, "mirror_axes", mirror_axes)
+
+ if pad_kwargs is None:
+ pad_kwargs = {'constant_values': 0}
+
+ # A very long time ago the mirror axes were (2, 3, 4) for a 3d network. This is just to intercept any old
+ # code that uses this convention
+ if len(mirror_axes):
+ if self.conv_op == nn.Conv2d:
+ if max(mirror_axes) > 1:
+ raise ValueError("mirror_axes must only contain axes 0 and 1 for a 2d network")
+ if self.conv_op == nn.Conv3d:
+ if max(mirror_axes) > 2:
+ raise ValueError("mirror_axes must only contain axes 0, 1 and 2 for a 3d network")
+
+ if self.training:
+ print('WARNING! Network is in train mode during inference. This may be intended, or not...')
+
+ assert len(x.shape) == 4, "data must have shape (c,x,y,z)"
+
+ if mixed_precision:
+ context = autocast
+ else:
+ context = no_op
+
+ with context():
+ with torch.no_grad():
+ if self.conv_op == nn.Conv3d:
+ if use_sliding_window:
+ res = self._internal_predict_3D_3Dconv_tiled(x, step_size, do_mirroring, mirror_axes, patch_size,
+ regions_class_order, use_gaussian, pad_border_mode,
+ pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu,
+ verbose=verbose)
+ else:
+ res = self._internal_predict_3D_3Dconv(x, patch_size, do_mirroring, mirror_axes, regions_class_order,
+ pad_border_mode, pad_kwargs=pad_kwargs, verbose=verbose)
+ elif self.conv_op == nn.Conv2d:
+ if use_sliding_window:
+ res = self._internal_predict_3D_2Dconv_tiled(x, patch_size, do_mirroring, mirror_axes, step_size,
+ regions_class_order, use_gaussian, pad_border_mode,
+ pad_kwargs, all_in_gpu, False)
+ else:
+ res = self._internal_predict_3D_2Dconv(x, patch_size, do_mirroring, mirror_axes, regions_class_order,
+ pad_border_mode, pad_kwargs, all_in_gpu, False)
+ else:
+ raise RuntimeError("Invalid conv op, cannot determine what dimensionality (2d/3d) the network is")
+
+ return res
+
+ def predict_2D(self, x, do_mirroring: bool, mirror_axes: tuple = (0, 1), use_sliding_window: bool = False,
+ step_size: float = 0.5, patch_size: tuple = None, regions_class_order: tuple = None,
+ use_gaussian: bool = False, pad_border_mode: str = "constant",
+ pad_kwargs: dict = None, all_in_gpu: bool = False,
+ verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Use this function to predict a 2D image. If this is a 3D U-Net it will crash because you cannot predict a 2D
+ image with a 3D network.
+
+ When running predictions, you need to specify whether you want to run fully convolutional or sliding window
+ based inference. We very strongly recommend you use sliding window with the default settings.
+
+ It is the responsibility of the user to make sure the network is in the proper mode (eval for inference!). If
+ the network is not in eval mode it will print a warning.
+
+ :param x: Your input data. Must be an np.ndarray of shape (c, x, y).
+ :param do_mirroring: If True, use test time data augmentation in the form of mirroring
+ :param mirror_axes: Determines which axes to use for mirroring. Per default, mirroring is done along both
+ axes
+ :param use_sliding_window: if True, run sliding window prediction. Heavily recommended! This is also the default
+ :param step_size: When running sliding window prediction, the step size determines the distance between adjacent
+ predictions. The smaller the step size, the denser the predictions (and the longer it takes!). Step size is given
+ as a fraction of the patch_size. 0.5 is the default and means that we advance by patch_size * 0.5 between
+ predictions. step_size cannot be larger than 1!
+ :param patch_size: The patch size that was used for training the network. Do not use different patch sizes here,
+ this will either crash or give potentially less accurate segmentations
+ :param regions_class_order: Fabian only
+ :param use_gaussian: (Only applies to sliding window prediction) If True, uses a Gaussian importance weighting
+ to weigh predictions closer to the center of the current patch higher than those at the borders. The reason
+ behind this is that the segmentation accuracy decreases towards the borders. Default (and recommended): True
+ :param pad_border_mode: leave this alone
+ :param pad_kwargs: leave this alone
+ :param all_in_gpu: experimental. You probably want to leave this as it is
+ :param verbose: Do you want a wall of text? If yes then set this to True
+ :param mixed_precision: if True, will run inference in mixed precision with autocast()
+ :return:
+ """
+ torch.cuda.empty_cache()
+
+ assert step_size <= 1, 'step_size must not be larger than 1. Otherwise there will be a gap between consecutive ' \
+ 'predictions'
+
+ if self.conv_op == nn.Conv3d:
+ raise RuntimeError("Cannot predict 2d if the network is 3d. Dummy.")
+
+ if verbose: print("debug: mirroring", do_mirroring, "mirror_axes", mirror_axes)
+
+ if pad_kwargs is None:
+ pad_kwargs = {'constant_values': 0}
+
+ # A very long time ago the mirror axes were (2, 3) for a 2d network. This is just to intercept any old
+ # code that uses this convention
+ if len(mirror_axes):
+ if max(mirror_axes) > 1:
+ raise ValueError("mirror_axes must only contain axes 0 and 1 for a 2d network")
+
+ if self.training:
+ print('WARNING! Network is in train mode during inference. This may be intended, or not...')
+
+ assert len(x.shape) == 3, "data must have shape (c,x,y)"
+
+ if mixed_precision:
+ context = autocast
+ else:
+ context = no_op
+
+ with context():
+ with torch.no_grad():
+ if self.conv_op == nn.Conv2d:
+ if use_sliding_window:
+ res = self._internal_predict_2D_2Dconv_tiled(x, step_size, do_mirroring, mirror_axes, patch_size,
+ regions_class_order, use_gaussian, pad_border_mode,
+ pad_kwargs, all_in_gpu, verbose)
+ else:
+ res = self._internal_predict_2D_2Dconv(x, patch_size, do_mirroring, mirror_axes, regions_class_order,
+ pad_border_mode, pad_kwargs, verbose)
+ else:
+ raise RuntimeError("Invalid conv op, cannot determine what dimensionality (2d/3d) the network is")
+
+ return res
+
+ @staticmethod
+ def _get_gaussian(patch_size, sigma_scale=1. / 8) -> np.ndarray:
+ tmp = np.zeros(patch_size)
+ center_coords = [i // 2 for i in patch_size]
+ sigmas = [i * sigma_scale for i in patch_size]
+ tmp[tuple(center_coords)] = 1
+ gaussian_importance_map = gaussian_filter(tmp, sigmas, 0, mode='constant', cval=0)
+ gaussian_importance_map = gaussian_importance_map / np.max(gaussian_importance_map)
+ gaussian_importance_map = gaussian_importance_map.astype(np.float32)
+
+ # gaussian_importance_map cannot be 0, otherwise we may end up with nans!
+ gaussian_importance_map[gaussian_importance_map == 0] = np.min(
+ gaussian_importance_map[gaussian_importance_map != 0])
+
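+ # Illustrative values (sketch): for patch_size=(64, 64) the sigmas are 64 / 8 = 8 voxels
+ # per axis, the map peaks at exactly 1.0 in the center voxel (32, 32) and decays towards
+ # the borders, so e.g. gaussian_importance_map[0, 0] < gaussian_importance_map[32, 32]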
+ return gaussian_importance_map
+
+ @staticmethod
+ def _compute_steps_for_sliding_window(patch_size: Tuple[int, ...], image_size: Tuple[int, ...], step_size: float) -> List[List[int]]:
+ assert all(i >= j for i, j in zip(image_size, patch_size)), "image size must be as large or larger than patch_size"
+ assert 0 < step_size <= 1, 'step_size must be larger than 0 and smaller or equal to 1'
+
+ # our step width is patch_size*step_size at most, but can be narrower. For example if we have image size of
+ # 110, patch size of 64 and step_size of 0.5, then we want to make 3 steps starting at coordinate 0, 23, 46
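+ # (worked through: target step = 64 * 0.5 = 32, num_steps = ceil((110 - 64) / 32) + 1 = 3,
+ # actual step = (110 - 64) / (3 - 1) = 23, hence the steps 0, 23, 46)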
+ target_step_sizes_in_voxels = [i * step_size for i in patch_size]
+
+ num_steps = [int(np.ceil((i - k) / j)) + 1 for i, j, k in zip(image_size, target_step_sizes_in_voxels, patch_size)]
+
+ steps = []
+ for dim in range(len(patch_size)):
+ # the highest step value for this dimension is
+ max_step_value = image_size[dim] - patch_size[dim]
+ if num_steps[dim] > 1:
+ actual_step_size = max_step_value / (num_steps[dim] - 1)
+ else:
+ actual_step_size = 99999999999 # does not matter because there is only one step at 0
+
+ steps_here = [int(np.round(actual_step_size * i)) for i in range(num_steps[dim])]
+
+ steps.append(steps_here)
+
+ return steps
+
+ def _internal_predict_3D_3Dconv_tiled(self, x: np.ndarray, step_size: float, do_mirroring: bool, mirror_axes: tuple,
+ patch_size: tuple, regions_class_order: tuple, use_gaussian: bool,
+ pad_border_mode: str, pad_kwargs: dict, all_in_gpu: bool,
+ verbose: bool) -> Tuple[np.ndarray, np.ndarray]:
+ # better safe than sorry
+ assert len(x.shape) == 4, "x must be (c, x, y, z)"
+
+ if verbose: print("step_size:", step_size)
+ if verbose: print("do mirror:", do_mirroring)
+
+ assert patch_size is not None, "patch_size cannot be None for tiled prediction"
+
+ # for sliding window inference the image must at least be as large as the patch size. It does not matter
+ # whether the shape is divisible by 2**num_pool as long as the patch size is
+ data, slicer = pad_nd_image(x, patch_size, pad_border_mode, pad_kwargs, True, None)
+ data_shape = data.shape # still c, x, y, z
+
+ # compute the steps for sliding window
+ steps = self._compute_steps_for_sliding_window(patch_size, data_shape[1:], step_size)
+ num_tiles = len(steps[0]) * len(steps[1]) * len(steps[2])
+
+ if verbose:
+ print("data shape:", data_shape)
+ print("patch size:", patch_size)
+ print("steps (x, y, and z):", steps)
+ print("number of tiles:", num_tiles)
+
+ # we only need to compute that once. It can take a while to compute this due to the large sigma in
+ # gaussian_filter
+ if use_gaussian and num_tiles > 1:
+ if self._gaussian_3d is None or not all(
+ [i == j for i, j in zip(patch_size, self._patch_size_for_gaussian_3d)]):
+ if verbose: print('computing Gaussian')
+ gaussian_importance_map = self._get_gaussian(patch_size, sigma_scale=1. / 8)
+
+ self._gaussian_3d = gaussian_importance_map
+ self._patch_size_for_gaussian_3d = patch_size
+ if verbose: print("done")
+ else:
+ if verbose: print("using precomputed Gaussian")
+ gaussian_importance_map = self._gaussian_3d
+
+ gaussian_importance_map = torch.from_numpy(gaussian_importance_map)
+
+ # predict on cpu if cuda is not available
+ if torch.cuda.is_available():
+ gaussian_importance_map = gaussian_importance_map.cuda(self.get_device(), non_blocking=True)
+
+ else:
+ gaussian_importance_map = None
+
+ if all_in_gpu:
+ # If we run the inference in GPU only (meaning all tensors are allocated on the GPU; this reduces
+ # CPU-GPU communication but requires more GPU memory) we need to preallocate a few things on GPU
+
+ if use_gaussian and num_tiles > 1:
+ # half precision for the outputs should be good enough. If the outputs here are half, the
+ # gaussian_importance_map should be as well
+ gaussian_importance_map = gaussian_importance_map.half()
+
+ # make sure we did not round anything to 0
+ gaussian_importance_map[gaussian_importance_map == 0] = gaussian_importance_map[
+ gaussian_importance_map != 0].min()
+
+ add_for_nb_of_preds = gaussian_importance_map
+ else:
+ add_for_nb_of_preds = torch.ones(patch_size, device=self.get_device())
+
+ if verbose: print("initializing result array (on GPU)")
+ aggregated_results = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
+ device=self.get_device())
+
+ if verbose: print("moving data to GPU")
+ data = torch.from_numpy(data).cuda(self.get_device(), non_blocking=True)
+
+ if verbose: print("initializing result_numsamples (on GPU)")
+ aggregated_nb_of_predictions = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
+ device=self.get_device())
+
+ else:
+ if use_gaussian and num_tiles > 1:
+ add_for_nb_of_preds = self._gaussian_3d
+ else:
+ add_for_nb_of_preds = np.ones(patch_size, dtype=np.float32)
+ aggregated_results = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
+ aggregated_nb_of_predictions = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
+
+ for x in steps[0]:
+ lb_x = x
+ ub_x = x + patch_size[0]
+ for y in steps[1]:
+ lb_y = y
+ ub_y = y + patch_size[1]
+ for z in steps[2]:
+ lb_z = z
+ ub_z = z + patch_size[2]
+
+ predicted_patch = self._internal_maybe_mirror_and_pred_3D(
+ data[None, :, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z], mirror_axes, do_mirroring,
+ gaussian_importance_map)[0]
+
+ if all_in_gpu:
+ predicted_patch = predicted_patch.half()
+ else:
+ predicted_patch = predicted_patch.cpu().numpy()
+
+ aggregated_results[:, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z] += predicted_patch
+ aggregated_nb_of_predictions[:, lb_x:ub_x, lb_y:ub_y, lb_z:ub_z] += add_for_nb_of_preds
+
+ # we reverse the padding here (remember that we padded the input to be at least as large as the patch size)
+ slicer = tuple(
+ [slice(0, aggregated_results.shape[i]) for i in
+ range(len(aggregated_results.shape) - (len(slicer) - 1))] + slicer[1:])
+ aggregated_results = aggregated_results[slicer]
+ aggregated_nb_of_predictions = aggregated_nb_of_predictions[slicer]
+
+ # computing the class_probabilities by dividing the aggregated result with result_numsamples
+ aggregated_results /= aggregated_nb_of_predictions
+ del aggregated_nb_of_predictions
+
+ if regions_class_order is None:
+ predicted_segmentation = aggregated_results.argmax(0)
+ else:
+ if all_in_gpu:
+ class_probabilities_here = aggregated_results.detach().cpu().numpy()
+ else:
+ class_probabilities_here = aggregated_results
+ predicted_segmentation = np.zeros(class_probabilities_here.shape[1:], dtype=np.float32)
+ for i, c in enumerate(regions_class_order):
+ predicted_segmentation[class_probabilities_here[i] > 0.5] = c
+
+ if all_in_gpu:
+ if verbose: print("copying results to CPU")
+
+ if regions_class_order is None:
+ predicted_segmentation = predicted_segmentation.detach().cpu().numpy()
+
+ aggregated_results = aggregated_results.detach().cpu().numpy()
+
+ if verbose: print("prediction done")
+ return predicted_segmentation, aggregated_results
+
+ def _internal_predict_2D_2Dconv(self, x: np.ndarray, min_size: Tuple[int, int], do_mirroring: bool,
+ mirror_axes: tuple = (0, 1, 2), regions_class_order: tuple = None,
+ pad_border_mode: str = "constant", pad_kwargs: dict = None,
+ verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ This one does fully convolutional inference. No sliding window
+ """
+ assert len(x.shape) == 3, "x must be (c, x, y)"
+
+ assert self.input_shape_must_be_divisible_by is not None, 'input_shape_must_be_divisible_by must be set to ' \
+ 'run _internal_predict_2D_2Dconv'
+ if verbose: print("do mirror:", do_mirroring)
+
+ data, slicer = pad_nd_image(x, min_size, pad_border_mode, pad_kwargs, True,
+ self.input_shape_must_be_divisible_by)
+
+ predicted_probabilities = self._internal_maybe_mirror_and_pred_2D(data[None], mirror_axes, do_mirroring,
+ None)[0]
+
+ slicer = tuple(
+ [slice(0, predicted_probabilities.shape[i]) for i in range(len(predicted_probabilities.shape) -
+ (len(slicer) - 1))] + slicer[1:])
+ predicted_probabilities = predicted_probabilities[slicer]
+
+ if regions_class_order is None:
+ predicted_segmentation = predicted_probabilities.argmax(0)
+ predicted_segmentation = predicted_segmentation.detach().cpu().numpy()
+ predicted_probabilities = predicted_probabilities.detach().cpu().numpy()
+ else:
+ predicted_probabilities = predicted_probabilities.detach().cpu().numpy()
+ predicted_segmentation = np.zeros(predicted_probabilities.shape[1:], dtype=np.float32)
+ for i, c in enumerate(regions_class_order):
+ predicted_segmentation[predicted_probabilities[i] > 0.5] = c
+
+ return predicted_segmentation, predicted_probabilities
+
+ def _internal_predict_3D_3Dconv(self, x: np.ndarray, min_size: Tuple[int, ...], do_mirroring: bool,
+ mirror_axes: tuple = (0, 1, 2), regions_class_order: tuple = None,
+ pad_border_mode: str = "constant", pad_kwargs: dict = None,
+ verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ This one does fully convolutional inference. No sliding window
+ """
+ assert len(x.shape) == 4, "x must be (c, x, y, z)"
+
+ assert self.input_shape_must_be_divisible_by is not None, 'input_shape_must_be_divisible_by must be set to ' \
+ 'run _internal_predict_3D_3Dconv'
+ if verbose: print("do mirror:", do_mirroring)
+
+ data, slicer = pad_nd_image(x, min_size, pad_border_mode, pad_kwargs, True,
+ self.input_shape_must_be_divisible_by)
+
+ predicted_probabilities = self._internal_maybe_mirror_and_pred_3D(data[None], mirror_axes, do_mirroring,
+ None)[0]
+
+ slicer = tuple(
+ [slice(0, predicted_probabilities.shape[i]) for i in range(len(predicted_probabilities.shape) -
+ (len(slicer) - 1))] + slicer[1:])
+ predicted_probabilities = predicted_probabilities[slicer]
+
+ if regions_class_order is None:
+ predicted_segmentation = predicted_probabilities.argmax(0)
+ predicted_segmentation = predicted_segmentation.detach().cpu().numpy()
+ predicted_probabilities = predicted_probabilities.detach().cpu().numpy()
+ else:
+ predicted_probabilities = predicted_probabilities.detach().cpu().numpy()
+ predicted_segmentation = np.zeros(predicted_probabilities.shape[1:], dtype=np.float32)
+ for i, c in enumerate(regions_class_order):
+ predicted_segmentation[predicted_probabilities[i] > 0.5] = c
+
+ return predicted_segmentation, predicted_probabilities
+
+ def _internal_maybe_mirror_and_pred_3D(self, x: Union[np.ndarray, torch.Tensor], mirror_axes: tuple,
+ do_mirroring: bool = True,
+ mult: Union[np.ndarray, torch.Tensor] = None) -> torch.Tensor:
+ assert len(x.shape) == 5, 'x must be (b, c, x, y, z)'
+
+ # if cuda available:
+ # everything in here takes place on the GPU. If x and mult are not yet on GPU this will be taken care of here
+ # we now return a cuda tensor! Not numpy array!
+
+ x = maybe_to_torch(x)
+ result_torch = torch.zeros([1, self.num_classes] + list(x.shape[2:]),
+ dtype=torch.float)
+
+ if torch.cuda.is_available():
+ x = to_cuda(x, gpu_id=self.get_device())
+ result_torch = result_torch.cuda(self.get_device(), non_blocking=True)
+
+ if mult is not None:
+ mult = maybe_to_torch(mult)
+ if torch.cuda.is_available():
+ mult = to_cuda(mult, gpu_id=self.get_device())
+
+ if do_mirroring:
+ mirror_idx = 8
+ num_results = 2 ** len(mirror_axes)
+ else:
+ mirror_idx = 1
+ num_results = 1
+
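+ # m enumerates all 2**3 = 8 flip combinations of the three spatial dims (4, 3, 2) of the
+ # (b, c, x, y, z) tensor. Only combinations whose flipped dims are all covered by
+ # mirror_axes contribute; each prediction is flipped back before being averaged in.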
+ for m in range(mirror_idx):
+ if m == 0:
+ pred = self.inference_apply_nonlin(self(x))
+ result_torch += 1 / num_results * pred
+
+ if m == 1 and (2 in mirror_axes):
+ pred = self.inference_apply_nonlin(self(torch.flip(x, (4, ))))
+ result_torch += 1 / num_results * torch.flip(pred, (4,))
+
+ if m == 2 and (1 in mirror_axes):
+ pred = self.inference_apply_nonlin(self(torch.flip(x, (3, ))))
+ result_torch += 1 / num_results * torch.flip(pred, (3,))
+
+ if m == 3 and (2 in mirror_axes) and (1 in mirror_axes):
+ pred = self.inference_apply_nonlin(self(torch.flip(x, (4, 3))))
+ result_torch += 1 / num_results * torch.flip(pred, (4, 3))
+
+ if m == 4 and (0 in mirror_axes):
+ pred = self.inference_apply_nonlin(self(torch.flip(x, (2, ))))
+ result_torch += 1 / num_results * torch.flip(pred, (2,))
+
+ if m == 5 and (0 in mirror_axes) and (2 in mirror_axes):
+ pred = self.inference_apply_nonlin(self(torch.flip(x, (4, 2))))
+ result_torch += 1 / num_results * torch.flip(pred, (4, 2))
+
+ if m == 6 and (0 in mirror_axes) and (1 in mirror_axes):
+ pred = self.inference_apply_nonlin(self(torch.flip(x, (3, 2))))
+ result_torch += 1 / num_results * torch.flip(pred, (3, 2))
+
+ if m == 7 and (0 in mirror_axes) and (1 in mirror_axes) and (2 in mirror_axes):
+ pred = self.inference_apply_nonlin(self(torch.flip(x, (4, 3, 2))))
+ result_torch += 1 / num_results * torch.flip(pred, (4, 3, 2))
+
+ if mult is not None:
+ result_torch[:, :] *= mult
+
+ return result_torch
+
+ def _internal_maybe_mirror_and_pred_2D(self, x: Union[np.ndarray, torch.Tensor], mirror_axes: tuple,
+ do_mirroring: bool = True,
+ mult: Union[np.ndarray, torch.Tensor] = None) -> torch.Tensor:
+ # if cuda available:
+ # everything in here takes place on the GPU. If x and mult are not yet on GPU this will be taken care of here
+ # we now return a cuda tensor! Not numpy array!
+
+ assert len(x.shape) == 4, 'x must be (b, c, x, y)'
+
+ x = maybe_to_torch(x)
+ result_torch = torch.zeros([x.shape[0], self.num_classes] + list(x.shape[2:]), dtype=torch.float)
+
+ if torch.cuda.is_available():
+ x = to_cuda(x, gpu_id=self.get_device())
+ result_torch = result_torch.cuda(self.get_device(), non_blocking=True)
+
+ if mult is not None:
+ mult = maybe_to_torch(mult)
+ if torch.cuda.is_available():
+ mult = to_cuda(mult, gpu_id=self.get_device())
+
+ if do_mirroring:
+ mirror_idx = 4
+ num_results = 2 ** len(mirror_axes)
+ else:
+ mirror_idx = 1
+ num_results = 1
+
+ for m in range(mirror_idx):
+ if m == 0:
+ pred = self.inference_apply_nonlin(self(x))
+ result_torch += 1 / num_results * pred
+
+ if m == 1 and (1 in mirror_axes):
+ pred = self.inference_apply_nonlin(self(torch.flip(x, (3, ))))
+ result_torch += 1 / num_results * torch.flip(pred, (3, ))
+
+ if m == 2 and (0 in mirror_axes):
+ pred = self.inference_apply_nonlin(self(torch.flip(x, (2, ))))
+ result_torch += 1 / num_results * torch.flip(pred, (2, ))
+
+ if m == 3 and (0 in mirror_axes) and (1 in mirror_axes):
+ pred = self.inference_apply_nonlin(self(torch.flip(x, (3, 2))))
+ result_torch += 1 / num_results * torch.flip(pred, (3, 2))
+
+ if mult is not None:
+ result_torch[:, :] *= mult
+
+ return result_torch
+
+ def _internal_predict_2D_2Dconv_tiled(self, x: np.ndarray, step_size: float, do_mirroring: bool, mirror_axes: tuple,
+ patch_size: tuple, regions_class_order: tuple, use_gaussian: bool,
+ pad_border_mode: str, pad_kwargs: dict, all_in_gpu: bool,
+ verbose: bool) -> Tuple[np.ndarray, np.ndarray]:
+ # better safe than sorry
+ assert len(x.shape) == 3, "x must be (c, x, y)"
+
+ if verbose: print("step_size:", step_size)
+ if verbose: print("do mirror:", do_mirroring)
+
+ assert patch_size is not None, "patch_size cannot be None for tiled prediction"
+
+ # for sliding window inference the image must at least be as large as the patch size. It does not matter
+ # whether the shape is divisible by 2**num_pool as long as the patch size is
+ data, slicer = pad_nd_image(x, patch_size, pad_border_mode, pad_kwargs, True, None)
+ data_shape = data.shape # still c, x, y
+
+ # compute the steps for sliding window
+ steps = self._compute_steps_for_sliding_window(patch_size, data_shape[1:], step_size)
+ num_tiles = len(steps[0]) * len(steps[1])
+
+ if verbose:
+ print("data shape:", data_shape)
+ print("patch size:", patch_size)
+ print("steps (x, y, and z):", steps)
+ print("number of tiles:", num_tiles)
+
+ # we only need to compute that once. It can take a while to compute this due to the large sigma in
+ # gaussian_filter
+ if use_gaussian and num_tiles > 1:
+ if self._gaussian_2d is None or not all(
+ [i == j for i, j in zip(patch_size, self._patch_size_for_gaussian_2d)]):
+ if verbose: print('computing Gaussian')
+ gaussian_importance_map = self._get_gaussian(patch_size, sigma_scale=1. / 8)
+
+ self._gaussian_2d = gaussian_importance_map
+ self._patch_size_for_gaussian_2d = patch_size
+ else:
+ if verbose: print("using precomputed Gaussian")
+ gaussian_importance_map = self._gaussian_2d
+
+ gaussian_importance_map = torch.from_numpy(gaussian_importance_map)
+ if torch.cuda.is_available():
+ gaussian_importance_map = gaussian_importance_map.cuda(self.get_device(), non_blocking=True)
+
+ else:
+ gaussian_importance_map = None
+
+ if all_in_gpu:
+ # If we run the inference in GPU only (meaning all tensors are allocated on the GPU; this reduces
+ # CPU-GPU communication but requires more GPU memory) we need to preallocate a few things on GPU
+
+ if use_gaussian and num_tiles > 1:
+ # half precision for the outputs should be good enough. If the outputs here are half, the
+ # gaussian_importance_map should be as well
+ gaussian_importance_map = gaussian_importance_map.half()
+
+ # make sure we did not round anything to 0
+ gaussian_importance_map[gaussian_importance_map == 0] = gaussian_importance_map[
+ gaussian_importance_map != 0].min()
+
+ add_for_nb_of_preds = gaussian_importance_map
+ else:
+ add_for_nb_of_preds = torch.ones(patch_size, device=self.get_device())
+
+ if verbose: print("initializing result array (on GPU)")
+ aggregated_results = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
+ device=self.get_device())
+
+ if verbose: print("moving data to GPU")
+ data = torch.from_numpy(data).cuda(self.get_device(), non_blocking=True)
+
+ if verbose: print("initializing result_numsamples (on GPU)")
+ aggregated_nb_of_predictions = torch.zeros([self.num_classes] + list(data.shape[1:]), dtype=torch.half,
+ device=self.get_device())
+ else:
+ if use_gaussian and num_tiles > 1:
+ add_for_nb_of_preds = self._gaussian_2d
+ else:
+ add_for_nb_of_preds = np.ones(patch_size, dtype=np.float32)
+ aggregated_results = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
+ aggregated_nb_of_predictions = np.zeros([self.num_classes] + list(data.shape[1:]), dtype=np.float32)
+
+ for x in steps[0]:
+ lb_x = x
+ ub_x = x + patch_size[0]
+ for y in steps[1]:
+ lb_y = y
+ ub_y = y + patch_size[1]
+
+ predicted_patch = self._internal_maybe_mirror_and_pred_2D(
+ data[None, :, lb_x:ub_x, lb_y:ub_y], mirror_axes, do_mirroring,
+ gaussian_importance_map)[0]
+
+ if all_in_gpu:
+ predicted_patch = predicted_patch.half()
+ else:
+ predicted_patch = predicted_patch.cpu().numpy()
+
+ aggregated_results[:, lb_x:ub_x, lb_y:ub_y] += predicted_patch
+ aggregated_nb_of_predictions[:, lb_x:ub_x, lb_y:ub_y] += add_for_nb_of_preds
+
+ # we reverse the padding here (remember that we padded the input to be at least as large as the patch size)
+ slicer = tuple(
+ [slice(0, aggregated_results.shape[i]) for i in
+ range(len(aggregated_results.shape) - (len(slicer) - 1))] + slicer[1:])
+ aggregated_results = aggregated_results[slicer]
+ aggregated_nb_of_predictions = aggregated_nb_of_predictions[slicer]
+
+ # computing the class_probabilities by dividing the aggregated result with result_numsamples
+ class_probabilities = aggregated_results / aggregated_nb_of_predictions
+
+ if regions_class_order is None:
+ predicted_segmentation = class_probabilities.argmax(0)
+ else:
+ if all_in_gpu:
+ class_probabilities_here = class_probabilities.detach().cpu().numpy()
+ else:
+ class_probabilities_here = class_probabilities
+ predicted_segmentation = np.zeros(class_probabilities_here.shape[1:], dtype=np.float32)
+ for i, c in enumerate(regions_class_order):
+ predicted_segmentation[class_probabilities_here[i] > 0.5] = c
+
+ if all_in_gpu:
+ if verbose: print("copying results to CPU")
+
+ if regions_class_order is None:
+ predicted_segmentation = predicted_segmentation.detach().cpu().numpy()
+
+ class_probabilities = class_probabilities.detach().cpu().numpy()
+
+ if verbose: print("prediction done")
+ return predicted_segmentation, class_probabilities
+
+ def _internal_predict_3D_2Dconv(self, x: np.ndarray, min_size: Tuple[int, int], do_mirroring: bool,
+ mirror_axes: tuple = (0, 1), regions_class_order: tuple = None,
+ pad_border_mode: str = "constant", pad_kwargs: dict = None,
+ all_in_gpu: bool = False, verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
+ if all_in_gpu:
+ raise NotImplementedError
+ assert len(x.shape) == 4, "data must be c, x, y, z"
+ predicted_segmentation = []
+ softmax_pred = []
+ for s in range(x.shape[1]):
+ pred_seg, softmax_pres = self._internal_predict_2D_2Dconv(
+ x[:, s], min_size, do_mirroring, mirror_axes, regions_class_order, pad_border_mode, pad_kwargs, verbose)
+ predicted_segmentation.append(pred_seg[None])
+ softmax_pred.append(softmax_pres[None])
+ predicted_segmentation = np.vstack(predicted_segmentation)
+ softmax_pred = np.vstack(softmax_pred).transpose((1, 0, 2, 3))
+ return predicted_segmentation, softmax_pred
+
+ def predict_3D_pseudo3D_2Dconv(self, x: np.ndarray, min_size: Tuple[int, int], do_mirroring: bool,
+ mirror_axes: tuple = (0, 1), regions_class_order: tuple = None,
+ pseudo3D_slices: int = 5, all_in_gpu: bool = False,
+ pad_border_mode: str = "constant", pad_kwargs: dict = None,
+ verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
+ if all_in_gpu:
+ raise NotImplementedError
+ assert len(x.shape) == 4, "data must be c, x, y, z"
+ assert pseudo3D_slices % 2 == 1, "pseudo3D_slices must be odd"
+ extra_slices = (pseudo3D_slices - 1) // 2
+
+ shp_for_pad = np.array(x.shape)
+ shp_for_pad[1] = extra_slices
+
+ pad = np.zeros(shp_for_pad, dtype=np.float32)
+ data = np.concatenate((pad, x, pad), 1)
+
+ predicted_segmentation = []
+ softmax_pred = []
+ for s in range(extra_slices, data.shape[1] - extra_slices):
+ d = data[:, (s - extra_slices):(s + extra_slices + 1)]
+ d = d.reshape((-1, d.shape[-2], d.shape[-1]))
+ pred_seg, softmax_pres = \
+ self._internal_predict_2D_2Dconv(d, min_size, do_mirroring, mirror_axes,
+ regions_class_order, pad_border_mode, pad_kwargs, verbose)
+ predicted_segmentation.append(pred_seg[None])
+ softmax_pred.append(softmax_pres[None])
+ predicted_segmentation = np.vstack(predicted_segmentation)
+ softmax_pred = np.vstack(softmax_pred).transpose((1, 0, 2, 3))
+
+ return predicted_segmentation, softmax_pred
+
+ def _internal_predict_3D_2Dconv_tiled(self, x: np.ndarray, patch_size: Tuple[int, int], do_mirroring: bool,
+ mirror_axes: tuple = (0, 1), step_size: float = 0.5,
+ regions_class_order: tuple = None, use_gaussian: bool = False,
+ pad_border_mode: str = "edge", pad_kwargs: dict =None,
+ all_in_gpu: bool = False,
+ verbose: bool = True) -> Tuple[np.ndarray, np.ndarray]:
+ if all_in_gpu:
+ raise NotImplementedError
+
+ assert len(x.shape) == 4, "data must be c, x, y, z"
+
+ predicted_segmentation = []
+ softmax_pred = []
+
+ for s in range(x.shape[1]):
+ pred_seg, softmax_pres = self._internal_predict_2D_2Dconv_tiled(
+ x[:, s], step_size, do_mirroring, mirror_axes, patch_size, regions_class_order, use_gaussian,
+ pad_border_mode, pad_kwargs, all_in_gpu, verbose)
+
+ predicted_segmentation.append(pred_seg[None])
+ softmax_pred.append(softmax_pres[None])
+
+ predicted_segmentation = np.vstack(predicted_segmentation)
+ softmax_pred = np.vstack(softmax_pred).transpose((1, 0, 2, 3))
+
+ return predicted_segmentation, softmax_pred
+
+
+if __name__ == '__main__':
+ print(SegmentationNetwork._compute_steps_for_sliding_window((30, 224, 224), (162, 529, 529), 0.5))
+ print(SegmentationNetwork._compute_steps_for_sliding_window((30, 224, 224), (162, 529, 529), 1))
+ print(SegmentationNetwork._compute_steps_for_sliding_window((30, 224, 224), (162, 529, 529), 0.1))
+
+ print(SegmentationNetwork._compute_steps_for_sliding_window((30, 224, 224), (60, 448, 224), 1))
+ print(SegmentationNetwork._compute_steps_for_sliding_window((30, 224, 224), (60, 448, 224), 0.5))
+
+ print(SegmentationNetwork._compute_steps_for_sliding_window((30, 224, 224), (30, 224, 224), 1))
+ print(SegmentationNetwork._compute_steps_for_sliding_window((30, 224, 224), (30, 224, 224), 0.125))
+
+
+ print(SegmentationNetwork._compute_steps_for_sliding_window((123, 54, 123), (246, 162, 369), 0.25))
+
+
+
diff --git a/nnunet/paths.py b/nnunet/paths.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fc5af5b60639bc8f570f6f015262243871df596
--- /dev/null
+++ b/nnunet/paths.py
@@ -0,0 +1,58 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from batchgenerators.utilities.file_and_folder_operations import maybe_mkdir_p, join
+
+# do not modify these unless you know what you are doing
+my_output_identifier = "nnUNet"
+default_plans_identifier = "nnUNetPlansv2.1"
+default_data_identifier = 'nnUNetData_plans_v2.1'
+default_trainer = "nnUNetTrainerV2"
+default_cascade_trainer = "nnUNetTrainerV2CascadeFullRes"
+
+"""
+PLEASE READ documentation/setting_up_paths.md FOR INFORMATION ON HOW TO SET THIS UP
+"""
+
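+# The three environment variables read below configure all nnU-Net paths. Illustrative
+# values (your actual locations will differ):
+#
+# export nnUNet_raw_data_base="/data/nnUNet_raw"
+# export nnUNet_preprocessed="/data/nnUNet_preprocessed"
+# export RESULTS_FOLDER="/data/nnUNet_results"
+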
+base = os.environ.get('nnUNet_raw_data_base')
+preprocessing_output_dir = os.environ.get('nnUNet_preprocessed')
+network_training_output_dir_base = os.environ.get('RESULTS_FOLDER')
+
+if base is not None:
+ nnUNet_raw_data = join(base, "nnUNet_raw_data")
+ nnUNet_cropped_data = join(base, "nnUNet_cropped_data")
+ maybe_mkdir_p(nnUNet_raw_data)
+ maybe_mkdir_p(nnUNet_cropped_data)
+else:
+ print("nnUNet_raw_data_base is not defined and nnU-Net can only be used on data for which preprocessed files "
+ "are already present on your system. nnU-Net cannot be used for experiment planning and preprocessing like "
+ "this. If this is not intended, please read documentation/setting_up_paths.md for information on how to set this up properly.")
+ nnUNet_cropped_data = nnUNet_raw_data = None
+
+if preprocessing_output_dir is not None:
+ maybe_mkdir_p(preprocessing_output_dir)
+else:
+ print("nnUNet_preprocessed is not defined and nnU-Net can not be used for preprocessing "
+ "or training. If this is not intended, please read documentation/setting_up_paths.md for information on how to set this up.")
+ preprocessing_output_dir = None
+
+if network_training_output_dir_base is not None:
+ network_training_output_dir = join(network_training_output_dir_base, my_output_identifier)
+ maybe_mkdir_p(network_training_output_dir)
+else:
+ print("RESULTS_FOLDER is not defined and nnU-Net cannot be used for training or "
+ "inference. If this is not intended behavior, please read documentation/setting_up_paths.md for information on how to set this "
+ "up.")
+ network_training_output_dir = None
diff --git a/nnunet/postprocessing/__init__.py b/nnunet/postprocessing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/postprocessing/connected_components.py b/nnunet/postprocessing/connected_components.py
new file mode 100644
index 0000000000000000000000000000000000000000..c69471ea9e366829d4490008afe2c51001edbb5c
--- /dev/null
+++ b/nnunet/postprocessing/connected_components.py
@@ -0,0 +1,428 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import ast
+from copy import deepcopy
+from multiprocessing.pool import Pool
+
+import numpy as np
+from nnunet.configuration import default_num_threads
+from nnunet.evaluation.evaluator import aggregate_scores
+from scipy.ndimage import label
+import SimpleITK as sitk
+from nnunet.utilities.sitk_stuff import copy_geometry
+from batchgenerators.utilities.file_and_folder_operations import *
+import shutil
+
+
+def load_remove_save(input_file: str, output_file: str, for_which_classes: list,
+ minimum_valid_object_size: dict = None):
+ # If minimum_valid_object_size is given, only objects smaller than minimum_valid_object_size will be
+ # removed (the largest object per class is always kept). Keys in minimum_valid_object_size must match
+ # entries in for_which_classes
+ img_in = sitk.ReadImage(input_file)
+ img_npy = sitk.GetArrayFromImage(img_in)
+ volume_per_voxel = float(np.prod(img_in.GetSpacing(), dtype=np.float64))
+
+ image, largest_removed, kept_size = remove_all_but_the_largest_connected_component(img_npy, for_which_classes,
+ volume_per_voxel,
+ minimum_valid_object_size)
+ # print(input_file, "kept:", kept_size)
+ img_out_itk = sitk.GetImageFromArray(image)
+ img_out_itk = copy_geometry(img_out_itk, img_in)
+ sitk.WriteImage(img_out_itk, output_file)
+ return largest_removed, kept_size
+
+
+def remove_all_but_the_largest_connected_component(image: np.ndarray, for_which_classes: list, volume_per_voxel: float,
+ minimum_valid_object_size: dict = None):
+ """
+ removes all but the largest connected component, individually for each class
+ :param image:
+ :param for_which_classes: can be None. Should be list of int. Can also be something like [(1, 2), 2, 4].
+ Here (1, 2) will be treated as a joint region, not individual classes (example LiTS here we can use (1, 2)
+ to use all foreground classes together)
+ :param minimum_valid_object_size: if given, only objects smaller than minimum_valid_object_size will be
+ removed (the largest object is always kept). Keys in minimum_valid_object_size must match entries in
+ for_which_classes
+ :return:
+ """
+ if for_which_classes is None:
+ for_which_classes = np.unique(image)
+ for_which_classes = for_which_classes[for_which_classes > 0]
+
+ assert 0 not in for_which_classes, "cannot remove background"
+ largest_removed = {}
+ kept_size = {}
+ for c in for_which_classes:
+ if isinstance(c, (list, tuple)):
+ c = tuple(c) # otherwise it can't be used as a key in the dict
+ mask = np.zeros_like(image, dtype=bool)
+ for cl in c:
+ mask[image == cl] = True
+ else:
+ mask = image == c
+ # get labelmap and number of objects
+ lmap, num_objects = label(mask.astype(int))
+
+ # collect object sizes
+ object_sizes = {}
+ for object_id in range(1, num_objects + 1):
+ object_sizes[object_id] = (lmap == object_id).sum() * volume_per_voxel
+
+ largest_removed[c] = None
+ kept_size[c] = None
+
+ if num_objects > 0:
+ # we always keep the largest object. We could also consider removing the largest object if it is smaller
+ # than minimum_valid_object_size in the future but we don't do that now.
+ maximum_size = max(object_sizes.values())
+ kept_size[c] = maximum_size
+
+ for object_id in range(1, num_objects + 1):
+ # we only remove objects that are not the largest
+ if object_sizes[object_id] != maximum_size:
+ # we only remove objects that are smaller than minimum_valid_object_size
+ remove = True
+ if minimum_valid_object_size is not None:
+ remove = object_sizes[object_id] < minimum_valid_object_size[c]
+ if remove:
+ image[(lmap == object_id) & mask] = 0
+ if largest_removed[c] is None:
+ largest_removed[c] = object_sizes[object_id]
+ else:
+ largest_removed[c] = max(largest_removed[c], object_sizes[object_id])
+ return image, largest_removed, kept_size
+
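+# Illustrative example (toy volume, 1 mm isotropic spacing assumed): keep only the largest
+# connected component of class 1.
+#
+# seg = np.zeros((1, 10, 10), dtype=np.uint8)
+# seg[0, 0:2, 0:2] = 1 # 4-voxel object
+# seg[0, 5:9, 5:9] = 1 # 16-voxel object
+# cleaned, removed, kept = remove_all_but_the_largest_connected_component(
+# seg, for_which_classes=[1], volume_per_voxel=1.0)
+# # the 4-voxel object is erased; removed[1] == 4.0 and kept[1] == 16.0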
+
+def load_postprocessing(json_file):
+ '''
+ loads the relevant part of the json file that is needed for applying postprocessing
+ :param json_file:
+ :return:
+ '''
+ a = load_json(json_file)
+ if 'min_valid_object_sizes' in a.keys():
+ min_valid_object_sizes = ast.literal_eval(a['min_valid_object_sizes'])
+ else:
+ min_valid_object_sizes = None
+ return a['for_which_classes'], min_valid_object_sizes
+
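+# The json written by determine_postprocessing (below) stores 'for_which_classes' as a
+# (possibly nested) list and 'min_valid_object_sizes' as the string representation of a
+# dict, hence the ast.literal_eval above. Illustrative content:
+#
+# {"for_which_classes": [[1, 2], 2], "min_valid_object_sizes": "{(1, 2): 74.2}"}
+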
+
+def determine_postprocessing(base, gt_labels_folder, raw_subfolder_name="validation_raw",
+ temp_folder="temp",
+ final_subf_name="validation_final", processes=default_num_threads,
+ dice_threshold=0, debug=False,
+ advanced_postprocessing=False,
+ pp_filename="postprocessing.json"):
+ """
+ :param base:
+ :param gt_labels_folder: subfolder of base with niftis of ground truth labels
+ :param raw_subfolder_name: subfolder of base with niftis of predicted (non-postprocessed) segmentations
+ :param temp_folder: used to store temporary data, will be deleted after we are done here unless debug=True
+ :param final_subf_name: final results will be stored here (subfolder of base)
+ :param processes:
+ :param dice_threshold: only apply postprocessing if the result is better than old_result + dice_threshold (can be used as eps)
+ :param debug: if True then the temporary files will not be deleted
+ :return:
+ """
+ # let's see what classes are in the dataset
+ classes = [int(i) for i in load_json(join(base, raw_subfolder_name, "summary.json"))['results']['mean'].keys() if
+ int(i) != 0]
+
+ folder_all_classes_as_fg = join(base, temp_folder + "_allClasses")
+ folder_per_class = join(base, temp_folder + "_perClass")
+
+ if isdir(folder_all_classes_as_fg):
+ shutil.rmtree(folder_all_classes_as_fg)
+ if isdir(folder_per_class):
+ shutil.rmtree(folder_per_class)
+
+ # multiprocessing rules
+ p = Pool(processes)
+
+ assert isfile(join(base, raw_subfolder_name, "summary.json")), "join(base, raw_subfolder_name) does not " \
+ "contain a summary.json"
+
+ # these are all the files we will be dealing with
+ fnames = subfiles(join(base, raw_subfolder_name), suffix=".nii.gz", join=False)
+
+ # make output and temp dir
+ maybe_mkdir_p(folder_all_classes_as_fg)
+ maybe_mkdir_p(folder_per_class)
+ maybe_mkdir_p(join(base, final_subf_name))
+
+ pp_results = {}
+ pp_results['dc_per_class_raw'] = {}
+ pp_results['dc_per_class_pp_all'] = {} # dice scores after treating all foreground classes as one
+ pp_results['dc_per_class_pp_per_class'] = {} # dice scores after removing everything except the largest cc,
+ # independently for each class, after we already did dc_per_class_pp_all
+ pp_results['for_which_classes'] = []
+ pp_results['min_valid_object_sizes'] = {}
+
+
+ validation_result_raw = load_json(join(base, raw_subfolder_name, "summary.json"))['results']
+ pp_results['num_samples'] = len(validation_result_raw['all'])
+ validation_result_raw = validation_result_raw['mean']
+
+ if advanced_postprocessing:
+ # first treat all foreground classes as one and remove all but the largest foreground connected component
+ results = []
+ for f in fnames:
+ predicted_segmentation = join(base, raw_subfolder_name, f)
+ # now remove all but the largest connected component for each class
+ output_file = join(folder_all_classes_as_fg, f)
+ results.append(p.starmap_async(load_remove_save, ((predicted_segmentation, output_file, (classes,)),)))
+
+ results = [i.get() for i in results]
+
+ # aggregate max_size_removed and min_size_kept
+ max_size_removed = {}
+ min_size_kept = {}
+ for tmp in results:
+ mx_rem, min_kept = tmp[0]
+ for k in mx_rem:
+ if mx_rem[k] is not None:
+ if max_size_removed.get(k) is None:
+ max_size_removed[k] = mx_rem[k]
+ else:
+ max_size_removed[k] = max(max_size_removed[k], mx_rem[k])
+ for k in min_kept:
+ if min_kept[k] is not None:
+ if min_size_kept.get(k) is None:
+ min_size_kept[k] = min_kept[k]
+ else:
+ min_size_kept[k] = min(min_size_kept[k], min_kept[k])
+
+ print("foreground vs background, smallest valid object size was", min_size_kept[tuple(classes)])
+ print("removing only objects smaller than that...")
+
+ else:
+ min_size_kept = None
+
+ # we need to rerun the step from above, now with the size constraint
+ pred_gt_tuples = []
+ results = []
+ # first treat all foreground classes as one and remove all but the largest foreground connected component
+ for f in fnames:
+ predicted_segmentation = join(base, raw_subfolder_name, f)
+ # now remove all but the largest connected component for each class
+ output_file = join(folder_all_classes_as_fg, f)
+ results.append(
+ p.starmap_async(load_remove_save, ((predicted_segmentation, output_file, (classes,), min_size_kept),)))
+ pred_gt_tuples.append([output_file, join(gt_labels_folder, f)])
+
+ _ = [i.get() for i in results]
+
+ # evaluate postprocessed predictions
+ _ = aggregate_scores(pred_gt_tuples, labels=classes,
+ json_output_file=join(folder_all_classes_as_fg, "summary.json"),
+ json_author="Fabian", num_threads=processes)
+
+ # now we need to figure out if doing this improved the dice scores. We implement this defensively: if a
+ # single class got worse as a result, we won't apply the postprocessing. We can change this in the future
+ # but right now I prefer to do it this way
+ validation_result_PP_test = load_json(join(folder_all_classes_as_fg, "summary.json"))['results']['mean']
+
+ for c in classes:
+ dc_raw = validation_result_raw[str(c)]['Dice']
+ dc_pp = validation_result_PP_test[str(c)]['Dice']
+ pp_results['dc_per_class_raw'][str(c)] = dc_raw
+ pp_results['dc_per_class_pp_all'][str(c)] = dc_pp
+
+ # true if new is better
+ do_fg_cc = False
+ comp = [pp_results['dc_per_class_pp_all'][str(cl)] > (pp_results['dc_per_class_raw'][str(cl)] + dice_threshold) for
+ cl in classes]
+ before = np.mean([pp_results['dc_per_class_raw'][str(cl)] for cl in classes])
+ after = np.mean([pp_results['dc_per_class_pp_all'][str(cl)] for cl in classes])
+ print("Foreground vs background")
+ print("before:", before)
+ print("after: ", after)
+ if any(comp):
+ # at least one class improved - yay!
+ # now check if another got worse
+ # true if new is worse
+ any_worse = any(
+ [pp_results['dc_per_class_pp_all'][str(cl)] < pp_results['dc_per_class_raw'][str(cl)] for cl in classes])
+ if not any_worse:
+ pp_results['for_which_classes'].append(classes)
+ if min_size_kept is not None:
+ pp_results['min_valid_object_sizes'].update(deepcopy(min_size_kept))
+ do_fg_cc = True
+ print("Removing all but the largest foreground region improved results!")
+ print('for_which_classes', classes)
+ print('min_valid_object_sizes', min_size_kept)
+ else:
+ # did not improve things - don't do it
+ pass
+
+ if len(classes) > 1:
+ # depending on whether we removed all but the largest foreground connected component, the source dir for
+ # the per-class step is either the temp dir or the raw dir
+ if do_fg_cc:
+ source = folder_all_classes_as_fg
+ else:
+ source = join(base, raw_subfolder_name)
+
+ if advanced_postprocessing:
+ # now run this for each class separately
+ results = []
+ for f in fnames:
+ predicted_segmentation = join(source, f)
+ output_file = join(folder_per_class, f)
+ results.append(p.starmap_async(load_remove_save, ((predicted_segmentation, output_file, classes),)))
+
+ results = [i.get() for i in results]
+
+ # aggregate max_size_removed and min_size_kept
+ max_size_removed = {}
+ min_size_kept = {}
+ for tmp in results:
+ mx_rem, min_kept = tmp[0]
+ for k in mx_rem:
+ if mx_rem[k] is not None:
+ if max_size_removed.get(k) is None:
+ max_size_removed[k] = mx_rem[k]
+ else:
+ max_size_removed[k] = max(max_size_removed[k], mx_rem[k])
+ for k in min_kept:
+ if min_kept[k] is not None:
+ if min_size_kept.get(k) is None:
+ min_size_kept[k] = min_kept[k]
+ else:
+ min_size_kept[k] = min(min_size_kept[k], min_kept[k])
+
+ print("classes treated separately, smallest valid object sizes are")
+ print(min_size_kept)
+ print("removing only objects smaller than that")
+ else:
+ min_size_kept = None
+
+ # rerun with the size thresholds from above
+ pred_gt_tuples = []
+ results = []
+ for f in fnames:
+ predicted_segmentation = join(source, f)
+ output_file = join(folder_per_class, f)
+ results.append(p.starmap_async(load_remove_save, ((predicted_segmentation, output_file, classes, min_size_kept),)))
+ pred_gt_tuples.append([output_file, join(gt_labels_folder, f)])
+
+ _ = [i.get() for i in results]
+
+ # evaluate postprocessed predictions
+ _ = aggregate_scores(pred_gt_tuples, labels=classes,
+ json_output_file=join(folder_per_class, "summary.json"),
+ json_author="Fabian", num_threads=processes)
+
+ if do_fg_cc:
+ old_res = deepcopy(validation_result_PP_test)
+ else:
+ old_res = validation_result_raw
+
+ # these are the new dice scores
+ validation_result_PP_test = load_json(join(folder_per_class, "summary.json"))['results']['mean']
+
+ for c in classes:
+ dc_raw = old_res[str(c)]['Dice']
+ dc_pp = validation_result_PP_test[str(c)]['Dice']
+ pp_results['dc_per_class_pp_per_class'][str(c)] = dc_pp
+ print(c)
+ print("before:", dc_raw)
+ print("after: ", dc_pp)
+
+ if dc_pp > (dc_raw + dice_threshold):
+ pp_results['for_which_classes'].append(int(c))
+ if min_size_kept is not None:
+ pp_results['min_valid_object_sizes'].update({c: min_size_kept[c]})
+ print("Removing all but the largest region for class %d improved results!" % c)
+ print('min_valid_object_sizes', min_size_kept)
+ else:
+ print("Only one class present, no need to do each class separately as this is covered in fg vs bg")
+
+ if not advanced_postprocessing:
+ pp_results['min_valid_object_sizes'] = None
+
+ print("done")
+ print("for which classes:")
+ print(pp_results['for_which_classes'])
+ print("min_object_sizes")
+ print(pp_results['min_valid_object_sizes'])
+
+ pp_results['validation_raw'] = raw_subfolder_name
+ pp_results['validation_final'] = final_subf_name
+
+ # now that we have a proper for_which_classes, apply that
+ pred_gt_tuples = []
+ results = []
+ for f in fnames:
+ predicted_segmentation = join(base, raw_subfolder_name, f)
+
+ # now remove all but the largest connected component for each class
+ output_file = join(base, final_subf_name, f)
+ results.append(p.starmap_async(load_remove_save, (
+ (predicted_segmentation, output_file, pp_results['for_which_classes'],
+ pp_results['min_valid_object_sizes']),)))
+
+ pred_gt_tuples.append([output_file,
+ join(gt_labels_folder, f)])
+
+ _ = [i.get() for i in results]
+ # evaluate postprocessed predictions
+ _ = aggregate_scores(pred_gt_tuples, labels=classes,
+ json_output_file=join(base, final_subf_name, "summary.json"),
+ json_author="Fabian", num_threads=processes)
+
+ pp_results['min_valid_object_sizes'] = str(pp_results['min_valid_object_sizes'])
+
+ save_json(pp_results, join(base, pp_filename))
+
+ # delete temp
+ if not debug:
+ shutil.rmtree(folder_per_class)
+ shutil.rmtree(folder_all_classes_as_fg)
+
+ p.close()
+ p.join()
+ print("done")
+
+
+def apply_postprocessing_to_folder(input_folder: str, output_folder: str, for_which_classes: list,
+ min_valid_object_size: dict = None, num_processes=8):
+ """
+ applies removing of all but the largest connected component to all niftis in a folder
+ :param min_valid_object_size:
+ :param input_folder:
+ :param output_folder:
+ :param for_which_classes:
+ :param num_processes:
+ :return:
+ """
+ maybe_mkdir_p(output_folder)
+ p = Pool(num_processes)
+ nii_files = subfiles(input_folder, suffix=".nii.gz", join=False)
+ input_files = [join(input_folder, i) for i in nii_files]
+ out_files = [join(output_folder, i) for i in nii_files]
+ results = p.starmap_async(load_remove_save, zip(input_files, out_files, [for_which_classes] * len(input_files),
+ [min_valid_object_size] * len(input_files)))
+ _ = results.get()
+ p.close()
+ p.join()
+
+
+if __name__ == "__main__":
+ input_folder = "/media/fabian/DKFZ/predictions_Fabian/Liver_and_LiverTumor"
+ output_folder = "/media/fabian/DKFZ/predictions_Fabian/Liver_and_LiverTumor_postprocessed"
+ for_which_classes = [(1, 2), ]
+ apply_postprocessing_to_folder(input_folder, output_folder, for_which_classes)
diff --git a/nnunet/postprocessing/consolidate_all_for_paper.py b/nnunet/postprocessing/consolidate_all_for_paper.py
new file mode 100644
index 0000000000000000000000000000000000000000..51c787261b4662894cf5b356c185669fc6eb33b9
--- /dev/null
+++ b/nnunet/postprocessing/consolidate_all_for_paper.py
@@ -0,0 +1,61 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.utilities.folder_names import get_output_folder_name
+
+
+def get_datasets():
+ configurations_all = {
+ "Task01_BrainTumour": ("3d_fullres", "2d"),
+ "Task02_Heart": ("3d_fullres", "2d",),
+ "Task03_Liver": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"),
+ "Task04_Hippocampus": ("3d_fullres", "2d",),
+ "Task05_Prostate": ("3d_fullres", "2d",),
+ "Task06_Lung": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"),
+ "Task07_Pancreas": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"),
+ "Task08_HepaticVessel": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"),
+ "Task09_Spleen": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"),
+ "Task10_Colon": ("3d_cascade_fullres", "3d_fullres", "3d_lowres", "2d"),
+ "Task48_KiTS_clean": ("3d_cascade_fullres", "3d_lowres", "3d_fullres", "2d"),
+ "Task27_ACDC": ("3d_fullres", "2d",),
+ "Task24_Promise": ("3d_fullres", "2d",),
+ "Task35_ISBILesionSegmentation": ("3d_fullres", "2d",),
+ "Task38_CHAOS_Task_3_5_Variant2": ("3d_fullres", "2d",),
+ "Task29_LITS": ("3d_cascade_fullres", "3d_lowres", "2d", "3d_fullres",),
+ "Task17_AbdominalOrganSegmentation": ("3d_cascade_fullres", "3d_lowres", "2d", "3d_fullres",),
+ "Task55_SegTHOR": ("3d_cascade_fullres", "3d_lowres", "3d_fullres", "2d",),
+ "Task56_VerSe": ("3d_cascade_fullres", "3d_lowres", "3d_fullres", "2d",),
+ }
+ return configurations_all
+
+
+def get_commands(configurations, regular_trainer="nnUNetTrainerV2", cascade_trainer="nnUNetTrainerV2CascadeFullRes",
+ plans="nnUNetPlansv2.1"):
+
+ node_pool = ["hdf18-gpu%02.0d" % i for i in range(1, 21)] + ["hdf19-gpu%02.0d" % i for i in range(1, 8)] + ["hdf19-gpu%02.0d" % i for i in range(11, 16)]
+ ctr = 0
+ for task in configurations:
+ models = configurations[task]
+ for m in models:
+ if m == "3d_cascade_fullres":
+ trainer = cascade_trainer
+ else:
+ trainer = regular_trainer
+
+ folder = get_output_folder_name(m, task, trainer, plans, overwrite_training_output_dir="/datasets/datasets_fabian/results/nnUNet")
+ node = node_pool[ctr % len(node_pool)]
+ print("bsub -m %s -q gputest -L /bin/bash \"source ~/.bashrc && python postprocessing/"
+ "consolidate_postprocessing.py -f %s\"" % (node, folder))
+ ctr += 1
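+
+
+if __name__ == "__main__":
+ # sketch of the assumed usage (this entry point is added for illustration only): prints one bsub
+ # command per task/configuration pair defined in get_datasets()
+ get_commands(get_datasets())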
diff --git a/nnunet/postprocessing/consolidate_postprocessing.py b/nnunet/postprocessing/consolidate_postprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..e735d74821922b7504252a6455e93419db083b36
--- /dev/null
+++ b/nnunet/postprocessing/consolidate_postprocessing.py
@@ -0,0 +1,97 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import shutil
+from typing import Tuple
+
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.configuration import default_num_threads
+from nnunet.evaluation.evaluator import aggregate_scores
+from nnunet.postprocessing.connected_components import determine_postprocessing
+import argparse
+
+
+def collect_cv_niftis(cv_folder: str, output_folder: str, validation_folder_name: str = 'validation_raw',
+ folds: tuple = (0, 1, 2, 3, 4)):
+ """
+ Copies the predicted niftis from the validation folders of all requested folds into output_folder.
+ """
+ validation_raw_folders = [join(cv_folder, "fold_%d" % i, validation_folder_name) for i in folds]
+ exist = [isdir(i) for i in validation_raw_folders]
+
+ if not all(exist):
+ raise RuntimeError("Some folds are missing. Please run the full 5-fold cross-validation. "
+ "The following folds seem to be missing: %s" %
+ [i for j, i in enumerate(folds) if not exist[j]])
+
+ # now copy all raw niftis into the output folder
+ maybe_mkdir_p(output_folder)
+ for folder in validation_raw_folders:
+ niftis = subfiles(folder, suffix=".nii.gz")
+ for n in niftis:
+ shutil.copy(n, output_folder)
+
+
+def consolidate_folds(output_folder_base, validation_folder_name: str = 'validation_raw',
+ advanced_postprocessing: bool = False, folds: Tuple[int] = (0, 1, 2, 3, 4)):
+ """
+ Used to determine the postprocessing for an experiment after all five folds have been completed. In the
+ validation of each fold, the postprocessing can only be determined on the cases within that fold. This can result
+ in different postprocessing decisions for different folds. In the end we can only apply one postprocessing per
+ experiment, so it has to be determined again on the pooled cross-validation predictions.
+ :param folds:
+ :param advanced_postprocessing:
+ :param output_folder_base: experiment output folder (fold_0, fold_1, etc must be subfolders of the given folder)
+ :param validation_folder_name: don't change this
+ :return:
+ """
+ output_folder_raw = join(output_folder_base, "cv_niftis_raw")
+ if isdir(output_folder_raw):
+ shutil.rmtree(output_folder_raw)
+
+ output_folder_gt = join(output_folder_base, "gt_niftis")
+ collect_cv_niftis(output_folder_base, output_folder_raw, validation_folder_name,
+ folds)
+
+ num_niftis_gt = len(subfiles(join(output_folder_base, "gt_niftis"), suffix='.nii.gz'))
+ # count niftis in there
+ num_niftis = len(subfiles(output_folder_raw, suffix='.nii.gz'))
+ if num_niftis != num_niftis_gt:
+ raise AssertionError("It does not seem like you trained all the folds! Train all folds first!")
+
+ # load a summary file so that we can know what class labels to expect
+ summary_fold0 = load_json(join(output_folder_base, "fold_0", validation_folder_name, "summary.json"))['results'][
+ 'mean']
+ classes = [int(i) for i in summary_fold0.keys()]
+ niftis = subfiles(output_folder_raw, join=False, suffix=".nii.gz")
+ test_pred_pairs = [(join(output_folder_raw, i), join(output_folder_gt, i)) for i in niftis]
+
+ # determine_postprocessing needs a summary.json file in the folder where the raw predictions are. We could compute
+ # that from the summary files of the five folds but I am feeling lazy today
+ aggregate_scores(test_pred_pairs, labels=classes, json_output_file=join(output_folder_raw, "summary.json"),
+ num_threads=default_num_threads)
+
+ determine_postprocessing(output_folder_base, output_folder_gt, 'cv_niftis_raw',
+ final_subf_name="cv_niftis_postprocessed", processes=default_num_threads,
+ advanced_postprocessing=advanced_postprocessing)
+ # determine_postprocessing will create a postprocessing.json file that can be used for inference
+
+
+if __name__ == "__main__":
+ argparser = argparse.ArgumentParser()
+ argparser.add_argument("-f", type=str, required=True, help="experiment output folder (fold_0, fold_1, "
+ "etc must be subfolders of the given folder)")
+
+ args = argparser.parse_args()
+
+ folder = args.f
+
+ consolidate_folds(folder)
diff --git a/nnunet/postprocessing/consolidate_postprocessing_simple.py b/nnunet/postprocessing/consolidate_postprocessing_simple.py
new file mode 100644
index 0000000000000000000000000000000000000000..34236445843b53c59bd8526e96218d0695f7ae8f
--- /dev/null
+++ b/nnunet/postprocessing/consolidate_postprocessing_simple.py
@@ -0,0 +1,60 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+from nnunet.postprocessing.consolidate_postprocessing import consolidate_folds
+from nnunet.utilities.folder_names import get_output_folder_name
+from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name
+from nnunet.paths import default_cascade_trainer, default_trainer, default_plans_identifier
+
+
+def main():
+ argparser = argparse.ArgumentParser(usage="Used to determine the postprocessing for a trained model. Useful "
+ "when the best configuration (2d, 3d_fullres etc) was selected manually.")
+ argparser.add_argument("-m", type=str, required=True, help="U-Net model (2d, 3d_lowres, 3d_fullres or "
+ "3d_cascade_fullres)")
+ argparser.add_argument("-t", type=str, required=True, help="Task name or id")
+ argparser.add_argument("-tr", type=str, required=False, default=None,
+ help="nnUNetTrainer class. Default: %s, unless 3d_cascade_fullres "
+ "(then it's %s)" % (default_trainer, default_cascade_trainer))
+ argparser.add_argument("-pl", type=str, required=False, default=default_plans_identifier,
+ help="Plans name. Default: %s" % default_plans_identifier)
+ argparser.add_argument("-val", type=str, required=False, default="validation_raw",
+ help="Validation folder name. Default: validation_raw")
+
+ args = argparser.parse_args()
+ model = args.m
+ task = args.t
+ trainer = args.tr
+ plans = args.pl
+ val = args.val
+
+ if not task.startswith("Task"):
+ task_id = int(task)
+ task = convert_id_to_task_name(task_id)
+
+ if trainer is None:
+ if model == "3d_cascade_fullres":
+ trainer = "nnUNetTrainerV2CascadeFullRes"
+ else:
+ trainer = "nnUNetTrainerV2"
+
+ folder = get_output_folder_name(model, task, trainer, plans, None)
+
+ consolidate_folds(folder, val)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/nnunet/preprocessing/__init__.py b/nnunet/preprocessing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b8078b9dddddf22182fec2555d8d118ea72622
--- /dev/null
+++ b/nnunet/preprocessing/__init__.py
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *
\ No newline at end of file
diff --git a/nnunet/preprocessing/cropping.py b/nnunet/preprocessing/cropping.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb0a92acb10ed13f9aad3e3781939f73476b92fb
--- /dev/null
+++ b/nnunet/preprocessing/cropping.py
@@ -0,0 +1,216 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import SimpleITK as sitk
+import numpy as np
+import shutil
+from batchgenerators.utilities.file_and_folder_operations import *
+from multiprocessing import Pool
+from collections import OrderedDict
+
+
+def create_nonzero_mask(data):
+ from scipy.ndimage import binary_fill_holes
+ assert len(data.shape) == 4 or len(data.shape) == 3, "data must have shape (C, X, Y, Z) or shape (C, X, Y)"
+ nonzero_mask = np.zeros(data.shape[1:], dtype=bool)
+ for c in range(data.shape[0]):
+ this_mask = data[c] != 0
+ nonzero_mask = nonzero_mask | this_mask
+ nonzero_mask = binary_fill_holes(nonzero_mask)
+ return nonzero_mask
+
+
+def get_bbox_from_mask(mask, outside_value=0):
+ mask_voxel_coords = np.where(mask != outside_value)
+ minzidx = int(np.min(mask_voxel_coords[0]))
+ maxzidx = int(np.max(mask_voxel_coords[0])) + 1
+ minxidx = int(np.min(mask_voxel_coords[1]))
+ maxxidx = int(np.max(mask_voxel_coords[1])) + 1
+ minyidx = int(np.min(mask_voxel_coords[2]))
+ maxyidx = int(np.max(mask_voxel_coords[2])) + 1
+ return [[minzidx, maxzidx], [minxidx, maxxidx], [minyidx, maxyidx]]
+
+
+def crop_to_bbox(image, bbox):
+ assert len(image.shape) == 3, "only supports 3d images"
+ resizer = (slice(bbox[0][0], bbox[0][1]), slice(bbox[1][0], bbox[1][1]), slice(bbox[2][0], bbox[2][1]))
+ return image[resizer]
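+
+
+# Illustrative sketch of the two helpers above (values assumed, not taken from the original code):
+#   mask = np.zeros((10, 10, 10), dtype=bool)
+#   mask[2:8, 3:7, 1:9] = True
+#   bbox = get_bbox_from_mask(mask)                   # [[2, 8], [3, 7], [1, 9]]
+#   crop_to_bbox(mask.astype(np.uint8), bbox).shape   # (6, 4, 8)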
+
+
+def get_case_identifier(case):
+ case_identifier = case[0].split("/")[-1].split(".nii.gz")[0][:-5]
+ return case_identifier
+
+
+def get_case_identifier_from_npz(case):
+ case_identifier = case.split("/")[-1][:-4]
+ return case_identifier
+
+
+def load_case_from_list_of_files(data_files, seg_file=None):
+ assert isinstance(data_files, list) or isinstance(data_files, tuple), "case must be either a list or a tuple"
+ properties = OrderedDict()
+ data_itk = [sitk.ReadImage(f) for f in data_files]
+
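+ # SimpleITK reports GetSize()/GetSpacing() in (x, y, z) order whereas GetArrayFromImage returns arrays in
+ # (z, y, x) order; the [[2, 1, 0]] index below reorders the former to match the latter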
+ properties["original_size_of_raw_data"] = np.array(data_itk[0].GetSize())[[2, 1, 0]]
+ properties["original_spacing"] = np.array(data_itk[0].GetSpacing())[[2, 1, 0]]
+ properties["list_of_data_files"] = data_files
+ properties["seg_file"] = seg_file
+
+ properties["itk_origin"] = data_itk[0].GetOrigin()
+ properties["itk_spacing"] = data_itk[0].GetSpacing()
+ properties["itk_direction"] = data_itk[0].GetDirection()
+
+ data_npy = np.vstack([sitk.GetArrayFromImage(d)[None] for d in data_itk])
+ if seg_file is not None:
+ seg_itk = sitk.ReadImage(seg_file)
+ seg_npy = sitk.GetArrayFromImage(seg_itk)[None].astype(np.float32)
+ else:
+ seg_npy = None
+ return data_npy.astype(np.float32), seg_npy, properties
+
+
+def crop_to_nonzero(data, seg=None, nonzero_label=-1):
+ """
+
+ :param data:
+ :param seg:
+ :param nonzero_label: this will be written into the segmentation map
+ :return:
+ """
+ nonzero_mask = create_nonzero_mask(data)
+ bbox = get_bbox_from_mask(nonzero_mask, 0)
+
+ cropped_data = []
+ for c in range(data.shape[0]):
+ cropped = crop_to_bbox(data[c], bbox)
+ cropped_data.append(cropped[None])
+ data = np.vstack(cropped_data)
+
+ if seg is not None:
+ cropped_seg = []
+ for c in range(seg.shape[0]):
+ cropped = crop_to_bbox(seg[c], bbox)
+ cropped_seg.append(cropped[None])
+ seg = np.vstack(cropped_seg)
+
+ nonzero_mask = crop_to_bbox(nonzero_mask, bbox)[None]
+ if seg is not None:
+ seg[(seg == 0) & (nonzero_mask == 0)] = nonzero_label
+ else:
+ nonzero_mask = nonzero_mask.astype(int)
+ nonzero_mask[nonzero_mask == 0] = nonzero_label
+ nonzero_mask[nonzero_mask > 0] = 0
+ seg = nonzero_mask
+ return data, seg, bbox
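+
+
+# Illustrative sketch of the cropping contract (shapes assumed): voxels inside the crop box that are zero in all
+# modalities receive nonzero_label in the returned seg.
+#   data = np.zeros((1, 12, 12, 12), dtype=np.float32)
+#   data[0, 2:8, 2:8, 2:8] = 100   # nonzero cube
+#   data[0, 9, 9, 9] = 100         # stray voxel enlarges the bounding box
+#   cropped, seg, bbox = crop_to_nonzero(data, seg=None, nonzero_label=-1)
+#   # bbox == [[2, 10], [2, 10], [2, 10]]; seg is 0 where data is nonzero (after hole filling), -1 elsewhere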
+
+
+def get_patient_identifiers_from_cropped_files(folder):
+ return [i.split("/")[-1][:-4] for i in subfiles(folder, join=True, suffix=".npz")]
+
+
+class ImageCropper(object):
+ def __init__(self, num_threads, output_folder=None):
+ """
+ This one finds a mask of nonzero elements (must be nonzero in all modalities) and crops the image to that mask.
+ In the case of BraTS and ISLES data this results in a significant reduction in image size
+ :param num_threads:
+ :param output_folder: where to store the cropped data
+ """
+ self.output_folder = output_folder
+ self.num_threads = num_threads
+
+ if self.output_folder is not None:
+ maybe_mkdir_p(self.output_folder)
+
+ @staticmethod
+ def crop(data, properties, seg=None):
+ shape_before = data.shape
+ data, seg, bbox = crop_to_nonzero(data, seg, nonzero_label=-1)
+ shape_after = data.shape
+ print("before crop:", shape_before, "after crop:", shape_after, "spacing:",
+ np.array(properties["original_spacing"]), "\n")
+
+ properties["crop_bbox"] = bbox
+ properties['classes'] = np.unique(seg)
+ seg[seg < -1] = 0
+ properties["size_after_cropping"] = data[0].shape
+ return data, seg, properties
+
+ @staticmethod
+ def crop_from_list_of_files(data_files, seg_file=None):
+ data, seg, properties = load_case_from_list_of_files(data_files, seg_file)
+ return ImageCropper.crop(data, properties, seg)
+
+ def load_crop_save(self, case, case_identifier, overwrite_existing=False):
+ try:
+ print(case_identifier)
+ if overwrite_existing \
+ or (not os.path.isfile(os.path.join(self.output_folder, "%s.npz" % case_identifier))
+ or not os.path.isfile(os.path.join(self.output_folder, "%s.pkl" % case_identifier))):
+
+ data, seg, properties = self.crop_from_list_of_files(case[:-1], case[-1])
+
+ all_data = np.vstack((data, seg))
+ np.savez_compressed(os.path.join(self.output_folder, "%s.npz" % case_identifier), data=all_data)
+ with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'wb') as f:
+ pickle.dump(properties, f)
+ except Exception as e:
+ print("Exception in", case_identifier, ":")
+ print(e)
+ raise e
+
+ def get_list_of_cropped_files(self):
+ return subfiles(self.output_folder, join=True, suffix=".npz")
+
+ def get_patient_identifiers_from_cropped_files(self):
+ return [i.split("/")[-1][:-4] for i in self.get_list_of_cropped_files()]
+
+ def run_cropping(self, list_of_files, overwrite_existing=False, output_folder=None):
+ """
+ Also copies the ground truth nifti segmentations into the preprocessed folder so that we can use them for
+ evaluation on the cluster.
+ :param list_of_files: list of lists of files [[PATIENTID_TIMESTEP_0000.nii.gz, ...], ...]. The last entry of each
+ case must be the ground truth segmentation (or None)
+ :param overwrite_existing:
+ :param output_folder:
+ :return:
+ """
+ if output_folder is not None:
+ self.output_folder = output_folder
+
+ output_folder_gt = os.path.join(self.output_folder, "gt_segmentations")
+ maybe_mkdir_p(output_folder_gt)
+ for j, case in enumerate(list_of_files):
+ if case[-1] is not None:
+ shutil.copy(case[-1], output_folder_gt)
+
+ list_of_args = []
+ for j, case in enumerate(list_of_files):
+ case_identifier = get_case_identifier(case)
+ list_of_args.append((case, case_identifier, overwrite_existing))
+
+ p = Pool(self.num_threads)
+ p.starmap(self.load_crop_save, list_of_args)
+ p.close()
+ p.join()
+
+ def load_properties(self, case_identifier):
+ with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'rb') as f:
+ properties = pickle.load(f)
+ return properties
+
+ def save_properties(self, case_identifier, properties):
+ with open(os.path.join(self.output_folder, "%s.pkl" % case_identifier), 'wb') as f:
+ pickle.dump(properties, f)
diff --git a/nnunet/preprocessing/custom_preprocessors/__init__.py b/nnunet/preprocessing/custom_preprocessors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/preprocessing/custom_preprocessors/preprocessor_scale_RGB_to_0_1.py b/nnunet/preprocessing/custom_preprocessors/preprocessor_scale_RGB_to_0_1.py
new file mode 100644
index 0000000000000000000000000000000000000000..b07273f6c445bdfed783858977a6079a1c036535
--- /dev/null
+++ b/nnunet/preprocessing/custom_preprocessors/preprocessor_scale_RGB_to_0_1.py
@@ -0,0 +1,66 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from nnunet.preprocessing.preprocessing import PreprocessorFor2D, resample_patient
+
+
+class GenericPreprocessor_scale_uint8_to_0_1(PreprocessorFor2D):
+ """
+ For RGB images with a value range of [0, 255]. This preprocessor overwrites the default normalization scheme by
+ normalizing intensity values through a simple division by 255 which rescales them to [0, 1]
+
+ NOTE THAT THIS INHERITS FROM PreprocessorFor2D, SO IT IS WRITTEN FOR 2D ONLY! WHEN CREATING A PREPROCESSOR FOR 3D
+ DATA, USE GenericPreprocessor AS PARENT!
+ """
+ def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
+ ############ THIS PART IS IDENTICAL TO PARENT CLASS ################
+
+ original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
+ before = {
+ 'spacing': properties["original_spacing"],
+ 'spacing_transposed': original_spacing_transposed,
+ 'data.shape (data is transposed)': data.shape
+ }
+ target_spacing[0] = original_spacing_transposed[0]
+ data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1,
+ force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0,
+ separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
+ after = {
+ 'spacing': target_spacing,
+ 'data.shape (data is resampled)': data.shape
+ }
+ print("before:", before, "\nafter: ", after, "\n")
+
+ if seg is not None: # hippocampus 243 has one voxel with -2 as label. wtf?
+ seg[seg < -1] = 0
+
+ properties["size_after_resampling"] = data[0].shape
+ properties["spacing_after_resampling"] = target_spacing
+ use_nonzero_mask = self.use_nonzero_mask
+
+ assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
+ "must have as many entries as data has " \
+ "modalities"
+ assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
+ " has modalities"
+
+ print("normalization...")
+
+ ############ HERE IS WHERE WE START CHANGING THINGS!!!!!!!################
+
+ # this is where the normalization takes place. We ignore use_nonzero_mask and normalization_scheme_per_modality
+ for c in range(len(data)):
+ data[c] = data[c].astype(np.float32) / 255.
+ return data, seg, properties
\ No newline at end of file
diff --git a/nnunet/preprocessing/preprocessing.py b/nnunet/preprocessing/preprocessing.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dbc6d9470d8f2efea42debda8a517aa7307bf1a
--- /dev/null
+++ b/nnunet/preprocessing/preprocessing.py
@@ -0,0 +1,950 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict
+from copy import deepcopy
+
+from batchgenerators.augmentations.utils import resize_segmentation
+from nnunet.configuration import default_num_threads, RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD
+from nnunet.preprocessing.cropping import get_case_identifier_from_npz, ImageCropper
+from skimage.transform import resize
+from scipy.ndimage.interpolation import map_coordinates
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from multiprocessing.pool import Pool
+
+
+def get_do_separate_z(spacing, anisotropy_threshold=RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD):
+ do_separate_z = (np.max(spacing) / np.min(spacing)) > anisotropy_threshold
+ return do_separate_z
+
+
+def get_lowres_axis(new_spacing):
+ axis = np.where(max(new_spacing) / np.array(new_spacing) == 1)[0]  # find the axis (or axes) with the largest spacing, i.e. the lowest resolution
+ return axis
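+
+
+# Illustrative values (assuming the default RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD of 3): for a spacing of
+# (5.0, 0.8, 0.8), max/min = 6.25 exceeds the threshold, so get_do_separate_z returns True and
+# get_lowres_axis((5.0, 0.8, 0.8)) returns array([0])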
+
+
+def resample_patient(data, seg, original_spacing, target_spacing, order_data=3, order_seg=0, force_separate_z=False,
+ order_z_data=0, order_z_seg=0,
+ separate_z_anisotropy_threshold=RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD):
+ """
+ :param data:
+ :param seg:
+ :param original_spacing:
+ :param target_spacing:
+ :param order_data:
+ :param order_seg:
+ :param force_separate_z: if None then we dynamically decide how to resample along z, if True/False then always
+ /never resample along z separately
+ :param order_z_seg: only applies if do_separate_z is True
+ :param order_z_data: only applies if do_separate_z is True
+ :param separate_z_anisotropy_threshold: if max_spacing > separate_z_anisotropy_threshold * min_spacing (per axis)
+ then resample along lowres axis with order_z_data/order_z_seg instead of order_data/order_seg
+
+ :return:
+ """
+ assert not ((data is None) and (seg is None))
+ if data is not None:
+ assert len(data.shape) == 4, "data must be c x y z"
+ if seg is not None:
+ assert len(seg.shape) == 4, "seg must be c x y z"
+
+ if data is not None:
+ shape = np.array(data[0].shape)
+ else:
+ shape = np.array(seg[0].shape)
+ new_shape = np.round(((np.array(original_spacing) / np.array(target_spacing)).astype(float) * shape)).astype(int)
+
+ if force_separate_z is not None:
+ do_separate_z = force_separate_z
+ if force_separate_z:
+ axis = get_lowres_axis(original_spacing)
+ else:
+ axis = None
+ else:
+ if get_do_separate_z(original_spacing, separate_z_anisotropy_threshold):
+ do_separate_z = True
+ axis = get_lowres_axis(original_spacing)
+ elif get_do_separate_z(target_spacing, separate_z_anisotropy_threshold):
+ do_separate_z = True
+ axis = get_lowres_axis(target_spacing)
+ else:
+ do_separate_z = False
+ axis = None
+
+ if axis is not None:
+ if len(axis) == 3:
+ # all three axes share the maximum spacing, i.e. the image is isotropic, so there is no low
+ # resolution axis that would need to be resampled separately
+ do_separate_z = False
+ elif len(axis) == 2:
+ # this happens for spacings like (0.24, 1.25, 1.25) for example. In that case we do not want to resample
+ # separately in the out of plane axis
+ do_separate_z = False
+ else:
+ pass
+
+ if data is not None:
+ data_reshaped = resample_data_or_seg(data, new_shape, False, axis, order_data, do_separate_z,
+ order_z=order_z_data)
+ else:
+ data_reshaped = None
+ if seg is not None:
+ seg_reshaped = resample_data_or_seg(seg, new_shape, True, axis, order_seg, do_separate_z, order_z=order_z_seg)
+ else:
+ seg_reshaped = None
+ return data_reshaped, seg_reshaped
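+
+
+# Hypothetical call sketch (shapes and spacings invented for illustration):
+#   data = np.random.rand(1, 40, 256, 256).astype(np.float32)
+#   seg = np.zeros((1, 40, 256, 256), dtype=np.float32)
+#   data_res, seg_res = resample_patient(data, seg, np.array([5.0, 0.8, 0.8]), np.array([3.0, 1.0, 1.0]),
+#                                        force_separate_z=None)
+#   # 5.0 / 0.8 exceeds the anisotropy threshold, so axis 0 is resampled separately (with order_z_data/order_z_seg)
+#   # while the remaining axes are resampled in-plane with order_data/order_seg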
+
+
+def resample_data_or_seg(data, new_shape, is_seg, axis=None, order=3, do_separate_z=False, order_z=0):
+ """
+ do_separate_z=True will resample the low resolution axis with order_z while the in-plane slices use order
+ :param data:
+ :param new_shape:
+ :param is_seg:
+ :param axis:
+ :param order:
+ :param do_separate_z:
+ :param order_z: only applies if do_separate_z is True
+ :return:
+ """
+ assert len(data.shape) == 4, "data must be (c, x, y, z)"
+ assert len(new_shape) == len(data.shape) - 1
+ if is_seg:
+ resize_fn = resize_segmentation
+ kwargs = OrderedDict()
+ else:
+ resize_fn = resize
+ kwargs = {'mode': 'edge', 'anti_aliasing': False}
+ dtype_data = data.dtype
+ shape = np.array(data[0].shape)
+ new_shape = np.array(new_shape)
+ if np.any(shape != new_shape):
+ data = data.astype(float)
+ if do_separate_z:
+ print("separate z, order in z is", order_z, "order inplane is", order)
+ assert len(axis) == 1, "only one anisotropic axis supported"
+ axis = axis[0]
+ if axis == 0:
+ new_shape_2d = new_shape[1:]
+ elif axis == 1:
+ new_shape_2d = new_shape[[0, 2]]
+ else:
+ new_shape_2d = new_shape[:-1]
+
+ reshaped_final_data = []
+ for c in range(data.shape[0]):
+ reshaped_data = []
+ for slice_id in range(shape[axis]):
+ if axis == 0:
+ reshaped_data.append(resize_fn(data[c, slice_id], new_shape_2d, order, **kwargs).astype(dtype_data))
+ elif axis == 1:
+ reshaped_data.append(resize_fn(data[c, :, slice_id], new_shape_2d, order, **kwargs).astype(dtype_data))
+ else:
+ reshaped_data.append(resize_fn(data[c, :, :, slice_id], new_shape_2d, order, **kwargs).astype(dtype_data))
+ reshaped_data = np.stack(reshaped_data, axis)
+ if shape[axis] != new_shape[axis]:
+
+ # The following few lines are blatantly copied and modified from scikit-image's resize()
+ rows, cols, dim = new_shape[0], new_shape[1], new_shape[2]
+ orig_rows, orig_cols, orig_dim = reshaped_data.shape
+
+ row_scale = float(orig_rows) / rows
+ col_scale = float(orig_cols) / cols
+ dim_scale = float(orig_dim) / dim
+
+ map_rows, map_cols, map_dims = np.mgrid[:rows, :cols, :dim]
+ map_rows = row_scale * (map_rows + 0.5) - 0.5
+ map_cols = col_scale * (map_cols + 0.5) - 0.5
+ map_dims = dim_scale * (map_dims + 0.5) - 0.5
+
+ coord_map = np.array([map_rows, map_cols, map_dims])
+ if not is_seg or order_z == 0:
+ reshaped_final_data.append(map_coordinates(reshaped_data, coord_map, order=order_z,
+ mode='nearest')[None].astype(dtype_data))
+ else:
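+ # resample each label as its own binary (float) mask and re-binarize afterwards, so that
+ # interpolation never mixes different label ids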
+ unique_labels = np.unique(reshaped_data)
+ reshaped = np.zeros(new_shape, dtype=dtype_data)
+
+ for i, cl in enumerate(unique_labels):
+ reshaped_multihot = np.round(
+ map_coordinates((reshaped_data == cl).astype(float), coord_map, order=order_z,
+ mode='nearest'))
+ reshaped[reshaped_multihot > 0.5] = cl
+ reshaped_final_data.append(reshaped[None].astype(dtype_data))
+ else:
+ reshaped_final_data.append(reshaped_data[None].astype(dtype_data))
+ reshaped_final_data = np.vstack(reshaped_final_data)
+ else:
+ print("no separate z, order", order)
+ reshaped = []
+ for c in range(data.shape[0]):
+ reshaped.append(resize_fn(data[c], new_shape, order, **kwargs)[None].astype(dtype_data))
+ reshaped_final_data = np.vstack(reshaped)
+ return reshaped_final_data.astype(dtype_data)
+ else:
+ print("no resampling necessary")
+ return data
+
+
+class GenericPreprocessor(object):
+ def __init__(self, normalization_scheme_per_modality, use_nonzero_mask, transpose_forward: (tuple, list), intensityproperties=None):
+ """
+
+ :param normalization_scheme_per_modality: dict mapping modality index to normalization scheme, e.g. {0: 'nonCT'}.
+ Recognized schemes are 'CT', 'CT2' and 'noNorm'; anything else falls back to z-score normalization
+ :param use_nonzero_mask: dict mapping modality index to bool, e.g. {0: False}. If True, normalization is
+ restricted to the nonzero region of the case
+ :param intensityproperties: per-modality intensity statistics (mean, sd, percentile_00_5, percentile_99_5)
+ collected from the training data foreground; required for the CT schemes
+ """
+ self.transpose_forward = transpose_forward
+ self.intensityproperties = intensityproperties
+ self.normalization_scheme_per_modality = normalization_scheme_per_modality
+ self.use_nonzero_mask = use_nonzero_mask
+
+ self.resample_separate_z_anisotropy_threshold = RESAMPLING_SEPARATE_Z_ANISO_THRESHOLD
+ self.resample_order_data = 3
+ self.resample_order_seg = 1
+
+ @staticmethod
+ def load_cropped(cropped_output_dir, case_identifier):
+ all_data = np.load(os.path.join(cropped_output_dir, "%s.npz" % case_identifier))['data']
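+ # ImageCropper stores the modalities and the segmentation stacked along the channel axis, so the last
+ # channel of 'data' is the segmentation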
+ data = all_data[:-1].astype(np.float32)
+ seg = all_data[-1:]
+ with open(os.path.join(cropped_output_dir, "%s.pkl" % case_identifier), 'rb') as f:
+ properties = pickle.load(f)
+ return data, seg, properties
+
+ def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
+ """
+ data and seg must already have been transposed by transpose_forward. properties are the un-transposed values
+ (spacing etc)
+ :param data:
+ :param target_spacing:
+ :param properties:
+ :param seg:
+ :param force_separate_z:
+ :return:
+ """
+
+ # target_spacing is already transposed, properties["original_spacing"] is not so we need to transpose it!
+ # data, seg are already transposed. Double check this using the properties
+ original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
+ before = {
+ 'spacing': properties["original_spacing"],
+ 'spacing_transposed': original_spacing_transposed,
+ 'data.shape (data is transposed)': data.shape
+ }
+
+ # remove nans
+ data[np.isnan(data)] = 0
+
+ data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing,
+ self.resample_order_data, self.resample_order_seg,
+ force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0,
+ separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
+ after = {
+ 'spacing': target_spacing,
+ 'data.shape (data is resampled)': data.shape
+ }
+ print("before:", before, "\nafter: ", after, "\n")
+
+ if seg is not None: # hippocampus 243 has one voxel with -2 as label. wtf?
+ seg[seg < -1] = 0
+
+ properties["size_after_resampling"] = data[0].shape
+ properties["spacing_after_resampling"] = target_spacing
+ use_nonzero_mask = self.use_nonzero_mask
+
+ assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
+ "must have as many entries as data has " \
+ "modalities"
+ assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
+ " has modalities"
+
+ for c in range(len(data)):
+ scheme = self.normalization_scheme_per_modality[c]
+ if scheme == "CT":
+ # clip to lb and ub from train data foreground and use foreground mn and sd from training data
+ assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
+ mean_intensity = self.intensityproperties[c]['mean']
+ std_intensity = self.intensityproperties[c]['sd']
+ lower_bound = self.intensityproperties[c]['percentile_00_5']
+ upper_bound = self.intensityproperties[c]['percentile_99_5']
+ data[c] = np.clip(data[c], lower_bound, upper_bound)
+ data[c] = (data[c] - mean_intensity) / std_intensity
+ if use_nonzero_mask[c]:
+ data[c][seg[-1] < 0] = 0
+ elif scheme == "CT2":
+ # clip to lb and ub from train data foreground, use mn and sd from each case for normalization
+ assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
+ lower_bound = self.intensityproperties[c]['percentile_00_5']
+ upper_bound = self.intensityproperties[c]['percentile_99_5']
+ mask = (data[c] > lower_bound) & (data[c] < upper_bound)
+ data[c] = np.clip(data[c], lower_bound, upper_bound)
+ mn = data[c][mask].mean()
+ sd = data[c][mask].std()
+ data[c] = (data[c] - mn) / sd
+ if use_nonzero_mask[c]:
+ data[c][seg[-1] < 0] = 0
+ elif scheme == 'noNorm':
+ print('no intensity normalization')
+ pass
+ else:
+ if use_nonzero_mask[c]:
+ mask = seg[-1] >= 0
+ data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8)
+ data[c][mask == 0] = 0
+ else:
+ mn = data[c].mean()
+ std = data[c].std()
+ # print(data[c].shape, data[c].dtype, mn, std)
+ data[c] = (data[c] - mn) / (std + 1e-8)
+ return data, seg, properties
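+
+ # Hypothetical construction sketch (dict contents assumed, mirroring the __init__ docstring above):
+ #   pp = GenericPreprocessor(normalization_scheme_per_modality={0: 'CT'},
+ #                            use_nonzero_mask={0: False},
+ #                            transpose_forward=(0, 1, 2),
+ #                            intensityproperties={0: {'mean': 100., 'sd': 50.,
+ #                                                     'percentile_00_5': -100., 'percentile_99_5': 300.}})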
+
+ def preprocess_test_case(self, data_files, target_spacing, seg_file=None, force_separate_z=None):
+ data, seg, properties = ImageCropper.crop_from_list_of_files(data_files, seg_file)
+
+ data = data.transpose((0, *[i + 1 for i in self.transpose_forward]))
+ seg = seg.transpose((0, *[i + 1 for i in self.transpose_forward]))
+
+ data, seg, properties = self.resample_and_normalize(data, target_spacing, properties, seg,
+ force_separate_z=force_separate_z)
+ return data.astype(np.float32), seg, properties
+
+ def _run_internal(self, target_spacing, case_identifier, output_folder_stage, cropped_output_dir, force_separate_z,
+ all_classes):
+ data, seg, properties = self.load_cropped(cropped_output_dir, case_identifier)
+
+ data = data.transpose((0, *[i + 1 for i in self.transpose_forward]))
+ seg = seg.transpose((0, *[i + 1 for i in self.transpose_forward]))
+
+ data, seg, properties = self.resample_and_normalize(data, target_spacing,
+ properties, seg, force_separate_z)
+
+ all_data = np.vstack((data, seg)).astype(np.float32)
+
+ # we need to find out where the classes are and sample some random locations
+ # let's do 10,000 samples per class
+ # seed this for reproducibility!
+ num_samples = 10000
+ min_percent_coverage = 0.01 # at least 1% of the class voxels need to be selected, otherwise it may be too sparse
+ rndst = np.random.RandomState(1234)
+ class_locs = {}
+ for c in all_classes:
+ all_locs = np.argwhere(all_data[-1] == c)
+ if len(all_locs) == 0:
+ class_locs[c] = []
+ continue
+ target_num_samples = min(num_samples, len(all_locs))
+ target_num_samples = max(target_num_samples, int(np.ceil(len(all_locs) * min_percent_coverage)))
+
+ selected = all_locs[rndst.choice(len(all_locs), target_num_samples, replace=False)]
+ class_locs[c] = selected
+ print(c, target_num_samples)
+ properties['class_locations'] = class_locs
+
+ print("saving: ", os.path.join(output_folder_stage, "%s.npz" % case_identifier))
+ np.savez_compressed(os.path.join(output_folder_stage, "%s.npz" % case_identifier),
+ data=all_data.astype(np.float32))
+ with open(os.path.join(output_folder_stage, "%s.pkl" % case_identifier), 'wb') as f:
+ pickle.dump(properties, f)
+
+ def run(self, target_spacings, input_folder_with_cropped_npz, output_folder, data_identifier,
+ num_threads=default_num_threads, force_separate_z=None):
+ """
+
+ :param target_spacings: list of lists [[1.25, 1.25, 5]]
+ :param input_folder_with_cropped_npz: folder with the cropped npz files; each stores an array of shape
+ (c, x, y, z) under the key 'data' (written via np.savez_compressed(fname.npz, data=arr))
+ :param output_folder:
+ :param num_threads:
+ :param force_separate_z: None
+ :return:
+ """
+ print("Initializing to run preprocessing")
+ print("npz folder:", input_folder_with_cropped_npz)
+ print("output_folder:", output_folder)
+ list_of_cropped_npz_files = subfiles(input_folder_with_cropped_npz, True, None, ".npz", True)
+ maybe_mkdir_p(output_folder)
+ num_stages = len(target_spacings)
+ if not isinstance(num_threads, (list, tuple, np.ndarray)):
+ num_threads = [num_threads] * num_stages
+
+ assert len(num_threads) == num_stages
+
+ # we need to know which classes are present in this dataset so that we can precompute where these classes are
+ # located. This is needed for oversampling foreground
+ all_classes = load_pickle(join(input_folder_with_cropped_npz, 'dataset_properties.pkl'))['all_classes']
+
+ for i in range(num_stages):
+ all_args = []
+ output_folder_stage = os.path.join(output_folder, data_identifier + "_stage%d" % i)
+ maybe_mkdir_p(output_folder_stage)
+ spacing = target_spacings[i]
+ for j, case in enumerate(list_of_cropped_npz_files):
+ case_identifier = get_case_identifier_from_npz(case)
+ args = spacing, case_identifier, output_folder_stage, input_folder_with_cropped_npz, force_separate_z, all_classes
+ all_args.append(args)
+ p = Pool(num_threads[i])
+ p.starmap(self._run_internal, all_args)
+ p.close()
+ p.join()
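+
+ # Hypothetical call sketch (paths, spacing and identifier invented for illustration):
+ #   GenericPreprocessor(...).run(target_spacings=[[3.0, 1.0, 1.0]],
+ #                                input_folder_with_cropped_npz="/path/to/TaskXXX_cropped",
+ #                                output_folder="/path/to/TaskXXX_preprocessed",
+ #                                data_identifier="nnUNetData_example",
+ #                                num_threads=8)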
+
+
+class GenericPreprocessor_linearResampling(GenericPreprocessor):
+ def __init__(self, normalization_scheme_per_modality, use_nonzero_mask, transpose_forward: (tuple, list),
+ intensityproperties=None):
+ super().__init__(normalization_scheme_per_modality, use_nonzero_mask, transpose_forward, intensityproperties)
+ self.resample_order_data = 1
+ self.resample_order_seg = 1
+
+
+class Preprocessor3DDifferentResampling(GenericPreprocessor):
+ def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
+ """
+ data and seg must already have been transposed by transpose_forward. properties are the un-transposed values
+ (spacing etc)
+ :param data:
+ :param target_spacing:
+ :param properties:
+ :param seg:
+ :param force_separate_z:
+ :return:
+ """
+
+ # target_spacing is already transposed, properties["original_spacing"] is not so we need to transpose it!
+ # data, seg are already transposed. Double check this using the properties
+ original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
+ before = {
+ 'spacing': properties["original_spacing"],
+ 'spacing_transposed': original_spacing_transposed,
+ 'data.shape (data is transposed)': data.shape
+ }
+
+ # remove nans
+ data[np.isnan(data)] = 0
+
+ data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1,
+ force_separate_z=force_separate_z, order_z_data=3, order_z_seg=1,
+ separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
+ after = {
+ 'spacing': target_spacing,
+ 'data.shape (data is resampled)': data.shape
+ }
+ print("before:", before, "\nafter: ", after, "\n")
+
+ if seg is not None: # hippocampus 243 has one voxel with -2 as label. wtf?
+ seg[seg < -1] = 0
+
+ properties["size_after_resampling"] = data[0].shape
+ properties["spacing_after_resampling"] = target_spacing
+ use_nonzero_mask = self.use_nonzero_mask
+
+ assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
+ "must have as many entries as data has " \
+ "modalities"
+ assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
+ " has modalities"
+
+ for c in range(len(data)):
+ scheme = self.normalization_scheme_per_modality[c]
+ if scheme == "CT":
+ # clip to lb and ub from train data foreground and use foreground mn and sd from training data
+ assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
+ mean_intensity = self.intensityproperties[c]['mean']
+ std_intensity = self.intensityproperties[c]['sd']
+ lower_bound = self.intensityproperties[c]['percentile_00_5']
+ upper_bound = self.intensityproperties[c]['percentile_99_5']
+ data[c] = np.clip(data[c], lower_bound, upper_bound)
+ data[c] = (data[c] - mean_intensity) / std_intensity
+ if use_nonzero_mask[c]:
+ data[c][seg[-1] < 0] = 0
+ elif scheme == "CT2":
+ # clip to lb and ub from train data foreground, use mn and sd from each case for normalization
+ assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
+ lower_bound = self.intensityproperties[c]['percentile_00_5']
+ upper_bound = self.intensityproperties[c]['percentile_99_5']
+ mask = (data[c] > lower_bound) & (data[c] < upper_bound)
+ data[c] = np.clip(data[c], lower_bound, upper_bound)
+ mn = data[c][mask].mean()
+ sd = data[c][mask].std()
+ data[c] = (data[c] - mn) / sd
+ if use_nonzero_mask[c]:
+ data[c][seg[-1] < 0] = 0
+ elif scheme == 'noNorm':
+ pass
+ else:
+ if use_nonzero_mask[c]:
+ mask = seg[-1] >= 0
+ else:
+ mask = np.ones(seg.shape[1:], dtype=bool)
+ data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8)
+ data[c][mask == 0] = 0
+ return data, seg, properties
+
+
+class Preprocessor3DBetterResampling(GenericPreprocessor):
+ """
+ This preprocessor always uses force_separate_z=False. It does resampling to the target spacing with third
+ order spline for data (just like GenericPreprocessor) and seg (unlike GenericPreprocessor). It never does separate
+ resampling in z.
+ """
+ def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=False):
+ """
+ data and seg must already have been transposed by transpose_forward. properties are the un-transposed values
+ (spacing etc)
+ :param data:
+ :param target_spacing:
+ :param properties:
+ :param seg:
+ :param force_separate_z:
+ :return:
+ """
+ if force_separate_z is not False:
+ print("WARNING: Preprocessor3DBetterResampling always uses force_separate_z=False. "
+ "You specified %s. Your choice is overwritten" % str(force_separate_z))
+ force_separate_z = False
+
+ # be safe
+ assert force_separate_z is False
+
+ # target_spacing is already transposed, properties["original_spacing"] is not so we need to transpose it!
+ # data, seg are already transposed. Double check this using the properties
+ original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
+ before = {
+ 'spacing': properties["original_spacing"],
+ 'spacing_transposed': original_spacing_transposed,
+ 'data.shape (data is transposed)': data.shape
+ }
+
+ # remove nans
+ data[np.isnan(data)] = 0
+
+ data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 3,
+ force_separate_z=force_separate_z, order_z_data=99999, order_z_seg=99999,
+ separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
+ after = {
+ 'spacing': target_spacing,
+ 'data.shape (data is resampled)': data.shape
+ }
+ print("before:", before, "\nafter: ", after, "\n")
+
+ if seg is not None: # hippocampus 243 has one voxel with -2 as label. wtf?
+ seg[seg < -1] = 0
+
+ properties["size_after_resampling"] = data[0].shape
+ properties["spacing_after_resampling"] = target_spacing
+ use_nonzero_mask = self.use_nonzero_mask
+
+ assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
+ "must have as many entries as data has " \
+ "modalities"
+ assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
+ " has modalities"
+
+ for c in range(len(data)):
+ scheme = self.normalization_scheme_per_modality[c]
+ if scheme == "CT":
+ # clip to lb and ub from train data foreground and use foreground mn and sd from training data
+ assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
+ mean_intensity = self.intensityproperties[c]['mean']
+ std_intensity = self.intensityproperties[c]['sd']
+ lower_bound = self.intensityproperties[c]['percentile_00_5']
+ upper_bound = self.intensityproperties[c]['percentile_99_5']
+ data[c] = np.clip(data[c], lower_bound, upper_bound)
+ data[c] = (data[c] - mean_intensity) / std_intensity
+ if use_nonzero_mask[c]:
+ data[c][seg[-1] < 0] = 0
+ elif scheme == "CT2":
+ # clip to lb and ub from train data foreground, use mn and sd from each case for normalization
+ assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
+ lower_bound = self.intensityproperties[c]['percentile_00_5']
+ upper_bound = self.intensityproperties[c]['percentile_99_5']
+ mask = (data[c] > lower_bound) & (data[c] < upper_bound)
+ data[c] = np.clip(data[c], lower_bound, upper_bound)
+ mn = data[c][mask].mean()
+ sd = data[c][mask].std()
+ data[c] = (data[c] - mn) / sd
+ if use_nonzero_mask[c]:
+ data[c][seg[-1] < 0] = 0
+ elif scheme == 'noNorm':
+ pass
+ else:
+ if use_nonzero_mask[c]:
+ mask = seg[-1] >= 0
+ else:
+ mask = np.ones(seg.shape[1:], dtype=bool)
+ data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8)
+ data[c][mask == 0] = 0
+ return data, seg, properties
+
+
+class PreprocessorFor2D(GenericPreprocessor):
+ def __init__(self, normalization_scheme_per_modality, use_nonzero_mask, transpose_forward: (tuple, list), intensityproperties=None):
+ super(PreprocessorFor2D, self).__init__(normalization_scheme_per_modality, use_nonzero_mask,
+ transpose_forward, intensityproperties)
+
+ def run(self, target_spacings, input_folder_with_cropped_npz, output_folder, data_identifier,
+ num_threads=default_num_threads, force_separate_z=None):
+ print("Initializing to run preprocessing")
+ print("npz folder:", input_folder_with_cropped_npz)
+ print("output_folder:", output_folder)
+ list_of_cropped_npz_files = subfiles(input_folder_with_cropped_npz, True, None, ".npz", True)
+ assert len(list_of_cropped_npz_files) != 0, "set list of files first"
+ maybe_mkdir_p(output_folder)
+ all_args = []
+ num_stages = len(target_spacings)
+
+ # we need to know which classes are present in this dataset so that we can precompute where these classes are
+ # located. This is needed for oversampling foreground
+ all_classes = load_pickle(join(input_folder_with_cropped_npz, 'dataset_properties.pkl'))['all_classes']
+
+ for i in range(num_stages):
+ output_folder_stage = os.path.join(output_folder, data_identifier + "_stage%d" % i)
+ maybe_mkdir_p(output_folder_stage)
+ spacing = target_spacings[i]
+ for j, case in enumerate(list_of_cropped_npz_files):
+ case_identifier = get_case_identifier_from_npz(case)
+ args = spacing, case_identifier, output_folder_stage, input_folder_with_cropped_npz, force_separate_z, all_classes
+ all_args.append(args)
+ p = Pool(num_threads)
+ p.starmap(self._run_internal, all_args)
+ p.close()
+ p.join()
+
+ def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
+ original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
+ before = {
+ 'spacing': properties["original_spacing"],
+ 'spacing_transposed': original_spacing_transposed,
+ 'data.shape (data is transposed)': data.shape
+ }
+ target_spacing[0] = original_spacing_transposed[0]
+ data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1,
+ force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0,
+ separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
+ after = {
+ 'spacing': target_spacing,
+ 'data.shape (data is resampled)': data.shape
+ }
+ print("before:", before, "\nafter: ", after, "\n")
+
+ if seg is not None: # hippocampus 243 has one voxel with -2 as label. wtf?
+ seg[seg < -1] = 0
+
+ properties["size_after_resampling"] = data[0].shape
+ properties["spacing_after_resampling"] = target_spacing
+ use_nonzero_mask = self.use_nonzero_mask
+
+ assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
+ "must have as many entries as data has " \
+ "modalities"
+ assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
+ " has modalities"
+
+ print("normalization...")
+
+ for c in range(len(data)):
+ scheme = self.normalization_scheme_per_modality[c]
+ if scheme == "CT":
+ # clip to lb and ub from train data foreground and use foreground mn and sd from training data
+ assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
+ mean_intensity = self.intensityproperties[c]['mean']
+ std_intensity = self.intensityproperties[c]['sd']
+ lower_bound = self.intensityproperties[c]['percentile_00_5']
+ upper_bound = self.intensityproperties[c]['percentile_99_5']
+ data[c] = np.clip(data[c], lower_bound, upper_bound)
+ data[c] = (data[c] - mean_intensity) / std_intensity
+ if use_nonzero_mask[c]:
+ data[c][seg[-1] < 0] = 0
+ elif scheme == "CT2":
+ # clip to lb and ub from train data foreground, use mn and sd from each case for normalization
+ assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
+ lower_bound = self.intensityproperties[c]['percentile_00_5']
+ upper_bound = self.intensityproperties[c]['percentile_99_5']
+ mask = (data[c] > lower_bound) & (data[c] < upper_bound)
+ data[c] = np.clip(data[c], lower_bound, upper_bound)
+ mn = data[c][mask].mean()
+ sd = data[c][mask].std()
+ data[c] = (data[c] - mn) / sd
+ if use_nonzero_mask[c]:
+ data[c][seg[-1] < 0] = 0
+ elif scheme == 'noNorm':
+ pass
+ else:
+ if use_nonzero_mask[c]:
+ mask = seg[-1] >= 0
+ else:
+ mask = np.ones(seg.shape[1:], dtype=bool)
+ data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8)
+ data[c][mask == 0] = 0
+ print("normalization done")
+ return data, seg, properties
+
+
+class PreprocessorFor2D_edgeLength512(PreprocessorFor2D):
+ target_edge_size = 512
+
+ def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
+ original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
+ before = {
+ 'spacing': properties["original_spacing"],
+ 'spacing_transposed': original_spacing_transposed,
+ 'data.shape (data is transposed)': data.shape
+ }
+ data_shape = data.shape[-2:]
+ smaller_edge = min(data_shape)
+ target_edge_size = self.target_edge_size
+ scale_factor = target_edge_size / smaller_edge
+ new_shape = [1] + [int(np.round(i * scale_factor)) for i in data_shape]
+ print(new_shape)
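+ # e.g. (illustrative values) an input of shape (3, 1, 300, 400) with target_edge_size=512 gives
+ # scale_factor 512 / 300 and new_shape [1, 512, 683]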
+
+ data = resample_data_or_seg(data, new_shape, False, None, 3, False, 0)
+ seg = resample_data_or_seg(seg, new_shape, True, None, 1, False, 0)
+
+ after = {
+ 'spacing': 'None',
+ 'data.shape (data is resampled)': data.shape
+ }
+ print("before:", before, "\nafter: ", after, "\n")
+
+ if seg is not None: # hippocampus 243 has one voxel with -2 as label. wtf?
+ seg[seg < -1] = 0
+
+ properties["size_after_resampling"] = data[0].shape
+ properties["spacing_after_resampling"] = target_spacing
+ use_nonzero_mask = self.use_nonzero_mask
+
+ assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
+ "must have as many entries as data has " \
+ "modalities"
+ assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
+ " has modalities"
+
+ print("normalization...")
+
+ for c in range(len(data)):
+ scheme = self.normalization_scheme_per_modality[c]
+ if scheme == "CT":
+ # clip to lb and ub from train data foreground and use foreground mn and sd from training data
+ assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
+ mean_intensity = self.intensityproperties[c]['mean']
+ std_intensity = self.intensityproperties[c]['sd']
+ lower_bound = self.intensityproperties[c]['percentile_00_5']
+ upper_bound = self.intensityproperties[c]['percentile_99_5']
+ data[c] = np.clip(data[c], lower_bound, upper_bound)
+ data[c] = (data[c] - mean_intensity) / std_intensity
+ if use_nonzero_mask[c]:
+ data[c][seg[-1] < 0] = 0
+ elif scheme == "CT2":
+ # clip to lb and ub from train data foreground, use mn and sd from each case for normalization
+ assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
+ lower_bound = self.intensityproperties[c]['percentile_00_5']
+ upper_bound = self.intensityproperties[c]['percentile_99_5']
+ mask = (data[c] > lower_bound) & (data[c] < upper_bound)
+ data[c] = np.clip(data[c], lower_bound, upper_bound)
+ mn = data[c][mask].mean()
+ sd = data[c][mask].std()
+ data[c] = (data[c] - mn) / sd
+ if use_nonzero_mask[c]:
+ data[c][seg[-1] < 0] = 0
+ elif scheme == 'noNorm':
+ pass
+ else:
+ if use_nonzero_mask[c]:
+ mask = seg[-1] >= 0
+ else:
+ mask = np.ones(seg.shape[1:], dtype=bool)
+ data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8)
+ data[c][mask == 0] = 0
+ print("normalization done")
+ return data, seg, properties
+
+
+class PreprocessorFor2D_edgeLength768(PreprocessorFor2D_edgeLength512):
+ target_edge_size = 768
+
+
+class PreprocessorFor3D_LeaveOriginalZSpacing(GenericPreprocessor):
+ """
+ 3d_lowres and 3d_fullres are not resampled along z!
+ """
+ def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
+ """
+ if target_spacing[0] is None or nan we use original_spacing_transposed[0] (no resampling along z)
+ :param data:
+ :param target_spacing:
+ :param properties:
+ :param seg:
+ :param force_separate_z:
+ :return:
+ """
+ original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
+ before = {
+ 'spacing': properties["original_spacing"],
+ 'spacing_transposed': original_spacing_transposed,
+ 'data.shape (data is transposed)': data.shape
+ }
+
+ # remove nans
+ data[np.isnan(data)] = 0
+ target_spacing = deepcopy(target_spacing)
+ if target_spacing[0] is None or np.isnan(target_spacing[0]):
+ target_spacing[0] = original_spacing_transposed[0]
+ #print(target_spacing, original_spacing_transposed)
+ data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1,
+ force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0,
+ separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
+ after = {
+ 'spacing': target_spacing,
+ 'data.shape (data is resampled)': data.shape
+ }
+ st = "before:" + str(before) + '\nafter: ' + str(after) + "\n"
+ print(st)
+
+ if seg is not None: # hippocampus 243 has one voxel with -2 as label. wtf?
+ seg[seg < -1] = 0
+
+ properties["size_after_resampling"] = data[0].shape
+ properties["spacing_after_resampling"] = target_spacing
+ use_nonzero_mask = self.use_nonzero_mask
+
+ assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
+ "must have as many entries as data has " \
+ "modalities"
+ assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
+ " has modalities"
+
+ for c in range(len(data)):
+ scheme = self.normalization_scheme_per_modality[c]
+ if scheme == "CT":
+ # clip to lb and ub from train data foreground and use foreground mn and sd from training data
+ assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
+ mean_intensity = self.intensityproperties[c]['mean']
+ std_intensity = self.intensityproperties[c]['sd']
+ lower_bound = self.intensityproperties[c]['percentile_00_5']
+ upper_bound = self.intensityproperties[c]['percentile_99_5']
+ data[c] = np.clip(data[c], lower_bound, upper_bound)
+ data[c] = (data[c] - mean_intensity) / std_intensity
+ if use_nonzero_mask[c]:
+ data[c][seg[-1] < 0] = 0
+ elif scheme == "CT2":
+ # clip to lb and ub from train data foreground, use mn and sd from each case for normalization
+ assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
+ lower_bound = self.intensityproperties[c]['percentile_00_5']
+ upper_bound = self.intensityproperties[c]['percentile_99_5']
+ mask = (data[c] > lower_bound) & (data[c] < upper_bound)
+ data[c] = np.clip(data[c], lower_bound, upper_bound)
+ mn = data[c][mask].mean()
+ sd = data[c][mask].std()
+ data[c] = (data[c] - mn) / sd
+ if use_nonzero_mask[c]:
+ data[c][seg[-1] < 0] = 0
+ elif scheme == 'noNorm':
+ pass
+ else:
+ if use_nonzero_mask[c]:
+ mask = seg[-1] >= 0
+ else:
+ mask = np.ones(seg.shape[1:], dtype=bool)
+ data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8)
+ data[c][mask == 0] = 0
+ return data, seg, properties
+
+ def run(self, target_spacings, input_folder_with_cropped_npz, output_folder, data_identifier,
+ num_threads=default_num_threads, force_separate_z=None):
+ for i in range(len(target_spacings)):
+ target_spacings[i][0] = None
+ super().run(target_spacings, input_folder_with_cropped_npz, output_folder, data_identifier,
+ num_threads, force_separate_z)
+
+
+class PreprocessorFor3D_NoResampling(GenericPreprocessor):
+ def resample_and_normalize(self, data, target_spacing, properties, seg=None, force_separate_z=None):
+ """
+ the passed target_spacing is ignored; data and seg are kept at their original spacing (no resampling is done)
+ :param data:
+ :param target_spacing:
+ :param properties:
+ :param seg:
+ :param force_separate_z:
+ :return:
+ """
+ original_spacing_transposed = np.array(properties["original_spacing"])[self.transpose_forward]
+ before = {
+ 'spacing': properties["original_spacing"],
+ 'spacing_transposed': original_spacing_transposed,
+ 'data.shape (data is transposed)': data.shape
+ }
+
+ # remove nans
+ data[np.isnan(data)] = 0
+ target_spacing = deepcopy(original_spacing_transposed)
+ #print(target_spacing, original_spacing_transposed)
+ data, seg = resample_patient(data, seg, np.array(original_spacing_transposed), target_spacing, 3, 1,
+ force_separate_z=force_separate_z, order_z_data=0, order_z_seg=0,
+ separate_z_anisotropy_threshold=self.resample_separate_z_anisotropy_threshold)
+ after = {
+ 'spacing': target_spacing,
+ 'data.shape (data is resampled)': data.shape
+ }
+ st = "before:" + str(before) + '\nafter' + str(after) + "\n"
+ print(st)
+
+ if seg is not None: # hippocampus case 243 has a stray voxel labeled -2; clamp anything below -1 to background
+ seg[seg < -1] = 0
+
+ properties["size_after_resampling"] = data[0].shape
+ properties["spacing_after_resampling"] = target_spacing
+ use_nonzero_mask = self.use_nonzero_mask
+
+ assert len(self.normalization_scheme_per_modality) == len(data), "self.normalization_scheme_per_modality " \
+ "must have as many entries as data has " \
+ "modalities"
+ assert len(self.use_nonzero_mask) == len(data), "self.use_nonzero_mask must have as many entries as data" \
+ " has modalities"
+
+ for c in range(len(data)):
+ scheme = self.normalization_scheme_per_modality[c]
+ if scheme == "CT":
+ # clip to lb and ub from train data foreground and use foreground mn and sd from training data
+ assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
+ mean_intensity = self.intensityproperties[c]['mean']
+ std_intensity = self.intensityproperties[c]['sd']
+ lower_bound = self.intensityproperties[c]['percentile_00_5']
+ upper_bound = self.intensityproperties[c]['percentile_99_5']
+ data[c] = np.clip(data[c], lower_bound, upper_bound)
+ data[c] = (data[c] - mean_intensity) / std_intensity
+ if use_nonzero_mask[c]:
+ data[c][seg[-1] < 0] = 0
+ elif scheme == "CT2":
+ # clip to lb and ub from train data foreground, use mean and sd from each case for normalization
+ assert self.intensityproperties is not None, "ERROR: if there is a CT then we need intensity properties"
+ lower_bound = self.intensityproperties[c]['percentile_00_5']
+ upper_bound = self.intensityproperties[c]['percentile_99_5']
+ mask = (data[c] > lower_bound) & (data[c] < upper_bound)
+ data[c] = np.clip(data[c], lower_bound, upper_bound)
+ mn = data[c][mask].mean()
+ sd = data[c][mask].std()
+ data[c] = (data[c] - mn) / sd
+ if use_nonzero_mask[c]:
+ data[c][seg[-1] < 0] = 0
+ elif scheme == 'noNorm':
+ pass
+ else:
+ if use_nonzero_mask[c]:
+ mask = seg[-1] >= 0
+ else:
+ mask = np.ones(seg.shape[1:], dtype=bool)
+ data[c][mask] = (data[c][mask] - data[c][mask].mean()) / (data[c][mask].std() + 1e-8)
+ data[c][mask == 0] = 0
+ return data, seg, properties
+
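As a quick reference for the "CT" branch implemented above, the following standalone sketch reproduces the normalization in isolation. It is illustrative only: the percentile, mean and sd values would normally come from the dataset fingerprint collected during experiment planning, and the numbers used here are made up.

    import numpy as np

    def normalize_ct(volume, props):
        # clip to the foreground 0.5/99.5 percentiles from the training data,
        # then z-score with the global foreground mean/sd (the "CT" scheme above)
        clipped = np.clip(volume, props['percentile_00_5'], props['percentile_99_5'])
        return (clipped - props['mean']) / props['sd']

    # hypothetical intensity properties, as produced by the dataset analyzer
    props = {'percentile_00_5': -100.0, 'percentile_99_5': 300.0, 'mean': 60.0, 'sd': 50.0}
    ct = np.random.uniform(-1000, 1000, size=(32, 64, 64)).astype(np.float32)
    ct_norm = normalize_ct(ct, props)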
diff --git a/nnunet/preprocessing/sanity_checks.py b/nnunet/preprocessing/sanity_checks.py
new file mode 100644
index 0000000000000000000000000000000000000000..171af6634e475fe9f8c97baf721a0131bb2ec013
--- /dev/null
+++ b/nnunet/preprocessing/sanity_checks.py
@@ -0,0 +1,274 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from multiprocessing import Pool
+
+import SimpleITK as sitk
+import nibabel as nib
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.configuration import default_num_threads
+
+
+def verify_all_same_orientation(folder):
+ """
+ This should run after cropping
+ :param folder:
+ :return:
+ """
+ nii_files = subfiles(folder, suffix=".nii.gz", join=True)
+ orientations = []
+ for n in nii_files:
+ img = nib.load(n)
+ affine = img.affine
+ orientation = nib.aff2axcodes(affine)
+ orientations.append(orientation)
+ # now we need to check whether they are all the same
+ orientations = np.array(orientations)
+ unique_orientations = np.unique(orientations, axis=0)
+ all_same = len(unique_orientations) == 1
+ return all_same, unique_orientations
+
+
+def verify_same_geometry(img_1: sitk.Image, img_2: sitk.Image):
+ ori1, spacing1, direction1, size1 = img_1.GetOrigin(), img_1.GetSpacing(), img_1.GetDirection(), img_1.GetSize()
+ ori2, spacing2, direction2, size2 = img_2.GetOrigin(), img_2.GetSpacing(), img_2.GetDirection(), img_2.GetSize()
+
+ same_ori = np.all(np.isclose(ori1, ori2))
+ if not same_ori:
+ print("the origin does not match between the images:")
+ print(ori1)
+ print(ori2)
+
+ same_spac = np.all(np.isclose(spacing1, spacing2))
+ if not same_spac:
+ print("the spacing does not match between the images")
+ print(spacing1)
+ print(spacing2)
+
+ same_dir = np.all(np.isclose(direction1, direction2))
+ if not same_dir:
+ print("the direction does not match between the images")
+ print(direction1)
+ print(direction2)
+
+ same_size = np.all(np.isclose(size1, size2))
+ if not same_size:
+ print("the size does not match between the images")
+ print(size1)
+ print(size2)
+
+ if same_ori and same_spac and same_dir and same_size:
+ return True
+ else:
+ return False
+
+
+def verify_contains_only_expected_labels(itk_img: str, valid_labels: (tuple, list)):
+ img_npy = sitk.GetArrayFromImage(sitk.ReadImage(itk_img))
+ uniques = np.unique(img_npy)
+ invalid_uniques = [i for i in uniques if i not in valid_labels]
+ if len(invalid_uniques) == 0:
+ r = True
+ else:
+ r = False
+ return r, invalid_uniques
+
+
+def verify_dataset_integrity(folder):
+ """
+ folder needs the imagesTr, imagesTs and labelsTr subfolders. There also needs to be a dataset.json
+ checks if all training cases and labels are present
+ checks if all test cases (if any) are present
+ for each case, checks whether all modalities are present
+ for each case, checks whether the pixel grids are aligned
+ checks whether the labels really only contain values they should
+ :param folder:
+ :return:
+ """
+ assert isfile(join(folder, "dataset.json")), "There needs to be a dataset.json file in folder, folder=%s" % folder
+ assert isdir(join(folder, "imagesTr")), "There needs to be a imagesTr subfolder in folder, folder=%s" % folder
+ assert isdir(join(folder, "labelsTr")), "There needs to be a labelsTr subfolder in folder, folder=%s" % folder
+ dataset = load_json(join(folder, "dataset.json"))
+ training_cases = dataset['training']
+ num_modalities = len(dataset['modality'].keys())
+ test_cases = dataset['test']
+ expected_train_identifiers = [i['image'].split("/")[-1][:-7] for i in training_cases]
+ expected_test_identifiers = [i.split("/")[-1][:-7] for i in test_cases]
+
+ ## check training set
+ nii_files_in_imagesTr = subfiles((join(folder, "imagesTr")), suffix=".nii.gz", join=False)
+ nii_files_in_labelsTr = subfiles((join(folder, "labelsTr")), suffix=".nii.gz", join=False)
+
+ label_files = []
+ geometries_OK = True
+ has_nan = False
+
+ # check all cases
+ if len(expected_train_identifiers) != len(np.unique(expected_train_identifiers)):
+ raise RuntimeError("found duplicate training cases in dataset.json")
+
+ print("Verifying training set")
+ for c in expected_train_identifiers:
+ print("checking case", c)
+ # check if all files are present
+ expected_label_file = join(folder, "labelsTr", c + ".nii.gz")
+ label_files.append(expected_label_file)
+ expected_image_files = [join(folder, "imagesTr", c + "_%04.0d.nii.gz" % i) for i in range(num_modalities)]
+ assert isfile(expected_label_file), "could not find label file for case %s. Expected file: \n%s" % (
+ c, expected_label_file)
+ assert all([isfile(i) for i in
+ expected_image_files]), "some image files are missing for case %s. Expected files:\n %s" % (
+ c, expected_image_files)
+
+ # verify that all modalities and the label have the same shape and geometry.
+ label_itk = sitk.ReadImage(expected_label_file)
+
+ nans_in_seg = np.any(np.isnan(sitk.GetArrayFromImage(label_itk)))
+ has_nan = has_nan | nans_in_seg
+ if nans_in_seg:
+ print("There are NAN values in segmentation %s" % expected_label_file)
+
+ images_itk = [sitk.ReadImage(i) for i in expected_image_files]
+ for i, img in enumerate(images_itk):
+ nans_in_image = np.any(np.isnan(sitk.GetArrayFromImage(img)))
+ has_nan = has_nan | nans_in_image
+ same_geometry = verify_same_geometry(img, label_itk)
+ if not same_geometry:
+ geometries_OK = False
+ print("The geometry of the image %s does not match the geometry of the label file. The pixel arrays "
+ "will not be aligned and nnU-Net cannot use this data. Please make sure your image modalities "
+ "are coregistered and have the same geometry as the label" % expected_image_files[0][:-12])
+ if nans_in_image:
+ print("There are NAN values in image %s" % expected_image_files[i])
+
+ # now remove checked files from the lists nii_files_in_imagesTr and nii_files_in_labelsTr
+ for i in expected_image_files:
+ nii_files_in_imagesTr.remove(os.path.basename(i))
+ nii_files_in_labelsTr.remove(os.path.basename(expected_label_file))
+
+ # check for stragglers
+ assert len(
+ nii_files_in_imagesTr) == 0, "there are training cases in imagesTr that are not listed in dataset.json: %s" % nii_files_in_imagesTr
+ assert len(
+ nii_files_in_labelsTr) == 0, "there are training cases in labelsTr that are not listed in dataset.json: %s" % nii_files_in_labelsTr
+
+ # verify that only properly declared values are present in the labels
+ print("Verifying label values")
+ expected_labels = list(int(i) for i in dataset['labels'].keys())
+ expected_labels.sort()
+
+ # check if labels are in consecutive order
+ assert expected_labels[0] == 0, 'The first label must be 0 and maps to the background'
+ labels_valid_consecutive = np.ediff1d(expected_labels) == 1
+ assert all(labels_valid_consecutive), f'Labels must be in consecutive order (0, 1, 2, ...). The labels {np.array(expected_labels)[1:][~labels_valid_consecutive]} do not satisfy this restriction'
+
+ p = Pool(default_num_threads)
+ results = p.starmap(verify_contains_only_expected_labels, zip(label_files, [expected_labels] * len(label_files)))
+ p.close()
+ p.join()
+
+ fail = False
+ print("Expected label values are", expected_labels)
+ for i, r in enumerate(results):
+ if not r[0]:
+ print("Unexpected labels found in file %s. Found these unexpected values (they should not be there) %s" % (
+ label_files[i], r[1]))
+ fail = True
+
+ if fail:
+ raise AssertionError(
+ "Found unexpected labels in the training dataset. Please correct that or adjust your dataset.json accordingly")
+ else:
+ print("Labels OK")
+
+ # check test set, but only if there actually is a test set
+ if len(expected_test_identifiers) > 0:
+ print("Verifying test set")
+ nii_files_in_imagesTs = subfiles((join(folder, "imagesTs")), suffix=".nii.gz", join=False)
+
+ for c in expected_test_identifiers:
+ # check if all files are present
+ expected_image_files = [join(folder, "imagesTs", c + "_%04.0d.nii.gz" % i) for i in range(num_modalities)]
+ assert all([isfile(i) for i in
+ expected_image_files]), "some image files are missing for case %s. Expected files:\n %s" % (
+ c, expected_image_files)
+
+ # verify that all modalities have the same geometry (there are no labels for the test set)
+ if num_modalities > 1:
+ images_itk = [sitk.ReadImage(i) for i in expected_image_files]
+ reference_img = images_itk[0]
+
+ for i, img in enumerate(images_itk[1:]):
+ assert verify_same_geometry(img, reference_img), "The modalities of the image %s do not seem to be " \
+ "registered. Please coregister your modalities." % (
+ expected_image_files[i])
+
+ # now remove checked files from the list nii_files_in_imagesTs
+ for i in expected_image_files:
+ nii_files_in_imagesTs.remove(os.path.basename(i))
+ assert len(
+ nii_files_in_imagesTs) == 0, "there are test cases in imagesTs that are not listed in dataset.json: %s" % nii_files_in_imagesTs
+
+ all_same, unique_orientations = verify_all_same_orientation(join(folder, "imagesTr"))
+ if not all_same:
+ print(
+ "WARNING: Not all images in the dataset have the same axis ordering. We very strongly recommend you correct that by reorienting the data. fslreorient2std should do the trick")
+ if not geometries_OK:
+ raise Warning("GEOMETRY MISMATCH FOUND! CHECK THE TEXT OUTPUT! This does not cause an error at this point but you should definitely check whether your geometries are alright!")
+ else:
+ print("Dataset OK")
+
+ if has_nan:
+ raise RuntimeError("Some images have nan values in them. This will break the training. See text output above to see which ones")
+
+
+def reorient_to_RAS(img_fname: str, output_fname: str = None):
+ img = nib.load(img_fname)
+ canonical_img = nib.as_closest_canonical(img)
+ if output_fname is None:
+ output_fname = img_fname
+ nib.save(canonical_img, output_fname)
+
+
+if __name__ == "__main__":
+ # scratch code for investigating geometry issues; the paths below are machine-specific examples
+
+ # load image
+ gt_itk = sitk.ReadImage(
+ "/media/fabian/Results/nnUNet/3d_fullres/Task064_KiTS_labelsFixed/nnUNetTrainerV2__nnUNetPlansv2.1/gt_niftis/case_00085.nii.gz")
+
+ # get numpy array
+ pred_npy = sitk.GetArrayFromImage(gt_itk)
+
+ # create new image from numpy array
+ prek_itk_new = sitk.GetImageFromArray(pred_npy)
+ # copy geometry
+ prek_itk_new.CopyInformation(gt_itk)
+ # prek_itk_new = copy_geometry(prek_itk_new, gt_itk)
+
+ # save
+ sitk.WriteImage(prek_itk_new, "test.mnc")
+
+ # load images in nib
+ gt = nib.load(
+ "/media/fabian/Results/nnUNet/3d_fullres/Task064_KiTS_labelsFixed/nnUNetTrainerV2__nnUNetPlansv2.1/gt_niftis/case_00085.nii.gz")
+ pred_nib = nib.load("test.mnc")
+
+ new_img_sitk = sitk.ReadImage("test.mnc")
+
+ np1 = sitk.GetArrayFromImage(gt_itk)
+ np2 = sitk.GetArrayFromImage(prek_itk_new)
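In practice the checks above are driven by a single call on the raw task folder. A minimal usage sketch (the path is a placeholder and the folder must follow the imagesTr/labelsTr/dataset.json layout asserted in verify_dataset_integrity):

    from nnunet.preprocessing.sanity_checks import verify_dataset_integrity

    # hypothetical raw data folder in the nnU-Net v1 layout
    verify_dataset_integrity("/data/nnUNet_raw_data/Task005_Prostate")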
diff --git a/nnunet/run/__init__.py b/nnunet/run/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b8078b9dddddf22182fec2555d8d118ea72622
--- /dev/null
+++ b/nnunet/run/__init__.py
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *
\ No newline at end of file
diff --git a/nnunet/run/default_configuration.py b/nnunet/run/default_configuration.py
new file mode 100644
index 0000000000000000000000000000000000000000..968d904e56c3516ad1a31805f4dbc8d1f3032573
--- /dev/null
+++ b/nnunet/run/default_configuration.py
@@ -0,0 +1,80 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import nnunet
+from nnunet.paths import network_training_output_dir, preprocessing_output_dir, default_plans_identifier
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.experiment_planning.summarize_plans import summarize_plans
+from nnunet.training.model_restore import recursive_find_python_class
+
+
+def get_configuration_from_output_folder(folder):
+ # split off network_training_output_dir
+ folder = folder[len(network_training_output_dir):]
+ if folder.startswith("/"):
+ folder = folder[1:]
+
+ configuration, task, trainer_and_plans_identifier = folder.split("/")
+ trainer, plans_identifier = trainer_and_plans_identifier.split("__")
+ return configuration, task, trainer, plans_identifier
+
+
+def get_default_configuration(network, task, network_trainer, plans_identifier=default_plans_identifier,
+ search_in=(nnunet.__path__[0], "training", "network_training"),
+ base_module='nnunet.training.network_training'):
+ assert network in ['2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres'], \
+ "network can only be one of the following: \'2d\', \'3d_lowres\', \'3d_fullres\', \'3d_cascade_fullres\'"
+
+ dataset_directory = join(preprocessing_output_dir, task)
+
+ if network == '2d':
+ plans_file = join(preprocessing_output_dir, task, plans_identifier + "_plans_2D.pkl")
+ else:
+ plans_file = join(preprocessing_output_dir, task, plans_identifier + "_plans_3D.pkl")
+
+ plans = load_pickle(plans_file)
+ possible_stages = list(plans['plans_per_stage'].keys())
+
+ if (network == '3d_cascade_fullres' or network == "3d_lowres") and len(possible_stages) == 1:
+ raise RuntimeError("3d_lowres/3d_cascade_fullres only applies if there is more than one stage. This task does "
+ "not require the cascade. Run 3d_fullres instead")
+
+ if network == '2d' or network == "3d_lowres":
+ stage = 0
+ else:
+ stage = possible_stages[-1]
+
+ trainer_class = recursive_find_python_class([join(*search_in)], network_trainer,
+ current_module=base_module)
+
+ output_folder_name = join(network_training_output_dir, network, task, network_trainer + "__" + plans_identifier)
+
+ print("###############################################")
+ print("I am running the following nnUNet: %s" % network)
+ print("My trainer class is: ", trainer_class)
+ print("For that I will be using the following configuration:")
+ summarize_plans(plans_file)
+ print("I am using stage %d from these plans" % stage)
+
+ if (network == '2d' or len(possible_stages) > 1) and not network == '3d_lowres':
+ batch_dice = True
+ print("I am using batch dice + CE loss")
+ else:
+ batch_dice = False
+ print("I am using sample dice + CE loss")
+
+ print("\nI am using data from this folder: ", join(dataset_directory, plans['data_identifier']))
+ print("###############################################")
+ return plans_file, output_folder_name, dataset_directory, batch_dice, stage, trainer_class
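For orientation, the tuple returned above is consumed like this. The task and trainer names are examples only, and the sketch assumes the preprocessed data and plans files already exist:

    from nnunet.run.default_configuration import get_default_configuration

    plans_file, output_folder, dataset_dir, batch_dice, stage, trainer_cls = \
        get_default_configuration('3d_fullres', 'Task005_Prostate', 'nnUNetTrainerV2')
    trainer = trainer_cls(plans_file, 0, output_folder=output_folder,
                          dataset_directory=dataset_dir, batch_dice=batch_dice, stage=stage)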
diff --git a/nnunet/run/load_pretrained_weights.py b/nnunet/run/load_pretrained_weights.py
new file mode 100644
index 0000000000000000000000000000000000000000..19379bff4806393f51aa3452b9ff8c9814814544
--- /dev/null
+++ b/nnunet/run/load_pretrained_weights.py
@@ -0,0 +1,62 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+
+
+def load_pretrained_weights(network, fname, verbose=False):
+ """
+ THIS DOES NOT TRANSFER SEGMENTATION HEADS!
+ """
+ saved_model = torch.load(fname)
+ pretrained_dict = saved_model['state_dict']
+
+ new_state_dict = {}
+
+ # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not
+ # match. Use heuristic to make it match
+ for k, value in pretrained_dict.items():
+ key = k
+ # remove module. prefix from DDP models
+ if key.startswith('module.'):
+ key = key[7:]
+ new_state_dict[key] = value
+
+ pretrained_dict = new_state_dict
+
+ model_dict = network.state_dict()
+ ok = True
+ for key, _ in model_dict.items():
+ if ('conv_blocks' in key):
+ if (key in pretrained_dict) and (model_dict[key].shape == pretrained_dict[key].shape):
+ continue
+ else:
+ ok = False
+ break
+
+ # filter unnecessary keys
+ if ok:
+ pretrained_dict = {k: v for k, v in pretrained_dict.items() if
+ (k in model_dict) and (model_dict[k].shape == pretrained_dict[k].shape)}
+ # 2. overwrite entries in the existing state dict
+ model_dict.update(pretrained_dict)
+ print("################### Loading pretrained weights from file ", fname, '###################')
+ if verbose:
+ print("Below is the list of overlapping blocks in pretrained model and nnUNet architecture:")
+ for key, _ in pretrained_dict.items():
+ print(key)
+ print("################### Done ###################")
+ network.load_state_dict(model_dict)
+ else:
+ raise RuntimeError("Pretrained weights are not compatible with the current network architecture")
+
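The DataParallel/DDP handling above boils down to stripping a 'module.' prefix from the checkpoint keys. A minimal standalone illustration of that heuristic (toy dict, not a real checkpoint):

    pretrained = {'module.conv_blocks.0.weight': 1, 'seg_outputs.0.weight': 2}
    stripped = {(k[7:] if k.startswith('module.') else k): v for k, v in pretrained.items()}
    assert 'conv_blocks.0.weight' in stripped and 'seg_outputs.0.weight' in stripped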
diff --git a/nnunet/run/run_training.py b/nnunet/run/run_training.py
new file mode 100644
index 0000000000000000000000000000000000000000..b91db3fded73a3e4d711a3bc99f9fe65c5719468
--- /dev/null
+++ b/nnunet/run/run_training.py
@@ -0,0 +1,199 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.run.default_configuration import get_default_configuration
+from nnunet.paths import default_plans_identifier
+from nnunet.run.load_pretrained_weights import load_pretrained_weights
+from nnunet.training.cascade_stuff.predict_next_stage import predict_next_stage
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+from nnunet.training.network_training.nnUNetTrainerCascadeFullRes import nnUNetTrainerCascadeFullRes
+from nnunet.training.network_training.nnUNetTrainerV2_CascadeFullRes import nnUNetTrainerV2CascadeFullRes
+from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("network")
+ parser.add_argument("network_trainer")
+ parser.add_argument("task", help="can be task name or task id")
+ parser.add_argument("fold", help='0, 1, ..., 5 or \'all\'')
+ parser.add_argument("-val", "--validation_only", help="use this if you want to only run the validation",
+ action="store_true")
+ parser.add_argument("-c", "--continue_training", help="use this if you want to continue a training",
+ action="store_true")
+ parser.add_argument("-p", help="plans identifier. Only change this if you created a custom experiment planner",
+ default=default_plans_identifier, required=False)
+ parser.add_argument("--use_compressed_data", default=False, action="store_true",
+ help="If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data "
+ "is much more CPU and RAM intensive and should only be used if you know what you are "
+ "doing", required=False)
+ parser.add_argument("--deterministic",
+ help="Makes training deterministic, but reduces training speed substantially. I (Fabian) think "
+ "this is not necessary. Deterministic training will make you overfit to some random seed. "
+ "Don't use that.",
+ required=False, default=False, action="store_true")
+ parser.add_argument("--npz", required=False, default=False, action="store_true", help="if set then nnUNet will "
+ "export npz files of "
+ "predicted segmentations "
+ "in the validation as well. "
+ "This is needed to run the "
+ "ensembling step so unless "
+ "you are developing nnUNet "
+ "you should enable this")
+ parser.add_argument("--find_lr", required=False, default=False, action="store_true",
+ help="not used here, just for fun")
+ parser.add_argument("--valbest", required=False, default=False, action="store_true",
+ help="hands off. This is not intended to be used")
+ parser.add_argument("--fp32", required=False, default=False, action="store_true",
+ help="disable mixed precision training and run old school fp32")
+ parser.add_argument("--val_folder", required=False, default="validation_raw",
+ help="name of the validation folder. No need to use this for most people")
+ parser.add_argument("--disable_saving", required=False, action='store_true',
+ help="If set nnU-Net will not save any parameter files (except a temporary checkpoint that "
+ "will be removed at the end of the training). Useful for development when you are "
+ "only interested in the results and want to save some disk space")
+ parser.add_argument("--disable_postprocessing_on_folds", required=False, action='store_true',
+ help="Running postprocessing on each fold only makes sense when developing with nnU-Net and "
+ "closely observing the model performance on specific configurations. You do not need it "
+ "when applying nnU-Net because the postprocessing for this will be determined only once "
+ "all five folds have been trained and nnUNet_find_best_configuration is called. Usually "
+ "running postprocessing on each fold is computationally cheap, but some users have "
+ "reported issues with very large images. If your images are large (>600x600x600 voxels) "
+ "you should consider setting this flag.")
+ # parser.add_argument("--interp_order", required=False, default=3, type=int,
+ # help="order of interpolation for segmentations. Testing purpose only. Hands off")
+ # parser.add_argument("--interp_order_z", required=False, default=0, type=int,
+ # help="order of interpolation along z if z is resampled separately. Testing purpose only. "
+ # "Hands off")
+ # parser.add_argument("--force_separate_z", required=False, default="None", type=str,
+ # help="force_separate_z resampling. Can be None, True or False. Testing purpose only. Hands off")
+ parser.add_argument('--val_disable_overwrite', action='store_false', default=True,
+ help='if set, validation does not overwrite existing segmentations')
+ parser.add_argument('--disable_next_stage_pred', action='store_true', default=False,
+ help='do not predict next stage')
+ parser.add_argument('-pretrained_weights', type=str, required=False, default=None,
+ help='path to nnU-Net checkpoint file to be used as pretrained model (use .model '
+ 'file, for example model_final_checkpoint.model). Will only be used when actually training. '
+ 'Optional. Beta. Use with caution.')
+
+ args = parser.parse_args()
+
+ task = args.task
+ fold = args.fold
+ network = args.network
+ network_trainer = args.network_trainer
+ validation_only = args.validation_only
+ plans_identifier = args.p
+ find_lr = args.find_lr
+ disable_postprocessing_on_folds = args.disable_postprocessing_on_folds
+
+ use_compressed_data = args.use_compressed_data
+ decompress_data = not use_compressed_data
+
+ deterministic = args.deterministic
+ valbest = args.valbest
+
+ fp32 = args.fp32
+ run_mixed_precision = not fp32
+
+ val_folder = args.val_folder
+ # interp_order = args.interp_order
+ # interp_order_z = args.interp_order_z
+ # force_separate_z = args.force_separate_z
+
+ if not task.startswith("Task"):
+ task_id = int(task)
+ task = convert_id_to_task_name(task_id)
+
+ if fold == 'all':
+ pass
+ else:
+ fold = int(fold)
+
+ # if force_separate_z == "None":
+ # force_separate_z = None
+ # elif force_separate_z == "False":
+ # force_separate_z = False
+ # elif force_separate_z == "True":
+ # force_separate_z = True
+ # else:
+ # raise ValueError("force_separate_z must be None, True or False. Given: %s" % force_separate_z)
+
+ plans_file, output_folder_name, dataset_directory, batch_dice, stage, \
+ trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier)
+
+ if trainer_class is None:
+ raise RuntimeError("Could not find trainer class in nnunet.training.network_training")
+
+ if network == "3d_cascade_fullres":
+ assert issubclass(trainer_class, (nnUNetTrainerCascadeFullRes, nnUNetTrainerV2CascadeFullRes)), \
+ "If running 3d_cascade_fullres then your " \
+ "trainer class must be derived from " \
+ "nnUNetTrainerCascadeFullRes"
+ else:
+ assert issubclass(trainer_class,
+ nnUNetTrainer), "network_trainer was found but is not derived from nnUNetTrainer"
+
+ trainer = trainer_class(plans_file, fold, output_folder=output_folder_name, dataset_directory=dataset_directory,
+ batch_dice=batch_dice, stage=stage, unpack_data=decompress_data,
+ deterministic=deterministic,
+ fp16=run_mixed_precision)
+ if args.disable_saving:
+ trainer.save_final_checkpoint = False # whether or not to save the final checkpoint
+ trainer.save_best_checkpoint = False # whether or not to save the best checkpoint according to
+ # self.best_val_eval_criterion_MA
+ trainer.save_intermediate_checkpoints = True  # whether or not to save checkpoint_latest. We need that in case
+ # the training crashes
+ trainer.save_latest_only = True # if false it will not store/overwrite _latest but separate files each
+
+ trainer.initialize(not validation_only)
+
+ if find_lr:
+ trainer.find_lr()
+ else:
+ if not validation_only:
+ if args.continue_training:
+ # -c was set, continue a previous training and ignore pretrained weights
+ trainer.load_latest_checkpoint()
+ elif (not args.continue_training) and (args.pretrained_weights is not None):
+ # we start a new training. If pretrained_weights are set, use them
+ load_pretrained_weights(trainer.network, args.pretrained_weights)
+ else:
+ # new training without pretrained weights, do nothing
+ pass
+
+ trainer.run_training()
+ else:
+ if valbest:
+ trainer.load_best_checkpoint(train=False)
+ else:
+ trainer.load_final_checkpoint(train=False)
+
+ trainer.network.eval()
+
+ # predict validation
+ trainer.validate(save_softmax=args.npz, validation_folder_name=val_folder,
+ run_postprocessing_on_folds=not disable_postprocessing_on_folds,
+ overwrite=args.val_disable_overwrite)
+
+ if network == '3d_lowres' and not args.disable_next_stage_pred:
+ print("predicting segmentations for the next stage of the cascade")
+ predict_next_stage(trainer, join(dataset_directory, trainer.plans['data_identifier'] + "_stage%d" % 1))
+
+
+if __name__ == "__main__":
+ main()
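This module is normally reached through the nnUNet_train console script. A hedged, roughly equivalent programmatic invocation (task name and trainer are examples, and this assumes nnunet and its environment variables are set up):

    import sys
    from nnunet.run.run_training import main

    # equivalent to: nnUNet_train 3d_fullres nnUNetTrainerV2 Task005_Prostate 0 --npz
    sys.argv = ['nnUNet_train', '3d_fullres', 'nnUNetTrainerV2', 'Task005_Prostate', '0', '--npz']
    main()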
diff --git a/nnunet/run/run_training_DDP.py b/nnunet/run/run_training_DDP.py
new file mode 100644
index 0000000000000000000000000000000000000000..80392de0e32292f72db6e21ce50dbe432c3900a4
--- /dev/null
+++ b/nnunet/run/run_training_DDP.py
@@ -0,0 +1,195 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.run.default_configuration import get_default_configuration
+from nnunet.paths import default_plans_identifier
+from nnunet.run.load_pretrained_weights import load_pretrained_weights
+from nnunet.training.cascade_stuff.predict_next_stage import predict_next_stage
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+from nnunet.training.network_training.nnUNetTrainerCascadeFullRes import nnUNetTrainerCascadeFullRes
+from nnunet.training.network_training.nnUNetTrainerV2_CascadeFullRes import nnUNetTrainerV2CascadeFullRes
+from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("network")
+ parser.add_argument("network_trainer")
+ parser.add_argument("task", help="can be task name or task id")
+ parser.add_argument("fold", help='0, 1, ..., 5 or \'all\'')
+ parser.add_argument("-val", "--validation_only", help="use this if you want to only run the validation",
+ action="store_true")
+ parser.add_argument("-c", "--continue_training", help="use this if you want to continue a training",
+ action="store_true")
+ parser.add_argument("-p", help="plans identifier. Only change this if you created a custom experiment planner",
+ default=default_plans_identifier, required=False)
+ parser.add_argument("--use_compressed_data", default=False, action="store_true",
+ help="If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data "
+ "is much more CPU and RAM intensive and should only be used if you know what you are "
+ "doing", required=False)
+ parser.add_argument("--deterministic",
+ help="Makes training deterministic, but reduces training speed substantially. I (Fabian) think "
+ "this is not necessary. Deterministic training will make you overfit to some random seed. "
+ "Don't use that.",
+ required=False, default=False, action="store_true")
+ parser.add_argument("--local_rank", default=0, type=int)
+ parser.add_argument("--fp32", required=False, default=False, action="store_true",
+ help="disable mixed precision training and run old school fp32")
+ parser.add_argument("--dbs", required=False, default=False, action="store_true", help="distribute batch size. If "
+ "True then whatever "
+ "batch_size is in plans will "
+ "be distributed over DDP "
+ "models, if False then each "
+ "model will have batch_size "
+ "for a total of "
+ "GPUs*batch_size")
+ parser.add_argument("--npz", required=False, default=False, action="store_true", help="if set then nnUNet will "
+ "export npz files of "
+ "predicted segmentations "
+ "in the vlaidation as well. "
+ "This is needed to run the "
+ "ensembling step so unless "
+ "you are developing nnUNet "
+ "you should enable this")
+ parser.add_argument("--valbest", required=False, default=False, action="store_true", help="")
+ parser.add_argument("--find_lr", required=False, default=False, action="store_true", help="")
+ parser.add_argument("--val_folder", required=False, default="validation_raw",
+ help="name of the validation folder. No need to use this for most people")
+ parser.add_argument("--disable_saving", required=False, action='store_true',
+ help="If set nnU-Net will not save any parameter files. Useful for development when you are "
+ "only interested in the results and want to save some disk space")
+ parser.add_argument("--disable_postprocessing_on_folds", required=False, action='store_true',
+ help="Running postprocessing on each fold only makes sense when developing with nnU-Net and "
+ "closely observing the model performance on specific configurations. You do not need it "
+ "when applying nnU-Net because the postprocessing for this will be determined only once "
+ "all five folds have been trained and nnUNet_find_best_configuration is called. Usually "
+ "running postprocessing on each fold is computationally cheap, but some users have "
+ "reported issues with very large images. If your images are large (>600x600x600 voxels) "
+ "you should consider setting this flag.")
+ # parser.add_argument("--interp_order", required=False, default=3, type=int,
+ # help="order of interpolation for segmentations. Testing purpose only. Hands off")
+ # parser.add_argument("--interp_order_z", required=False, default=0, type=int,
+ # help="order of interpolation along z if z is resampled separately. Testing purpose only. "
+ # "Hands off")
+ # parser.add_argument("--force_separate_z", required=False, default="None", type=str,
+ # help="force_separate_z resampling. Can be None, True or False. Testing purpose only. Hands off")
+ parser.add_argument('-pretrained_weights', type=str, required=False, default=None,
+ help='path to nnU-Net checkpoint file to be used as pretrained model (use .model '
+ 'file, for example model_final_checkpoint.model). Will only be used when actually training. '
+ 'Optional. Beta. Use with caution.')
+
+ args = parser.parse_args()
+
+ task = args.task
+ fold = args.fold
+ network = args.network
+ network_trainer = args.network_trainer
+ validation_only = args.validation_only
+ plans_identifier = args.p
+ use_compressed_data = args.use_compressed_data
+ decompress_data = not use_compressed_data
+ deterministic = args.deterministic
+ valbest = args.valbest
+ find_lr = args.find_lr
+ val_folder = args.val_folder
+ # interp_order = args.interp_order
+ # interp_order_z = args.interp_order_z
+ # force_separate_z = args.force_separate_z
+ fp32 = args.fp32
+ disable_postprocessing_on_folds = args.disable_postprocessing_on_folds
+
+ if not task.startswith("Task"):
+ task_id = int(task)
+ task = convert_id_to_task_name(task_id)
+
+ if fold == 'all':
+ pass
+ else:
+ fold = int(fold)
+ #
+ # if force_separate_z == "None":
+ # force_separate_z = None
+ # elif force_separate_z == "False":
+ # force_separate_z = False
+ # elif force_separate_z == "True":
+ # force_separate_z = True
+ # else:
+ # raise ValueError("force_separate_z must be None, True or False. Given: %s" % force_separate_z)
+
+ plans_file, output_folder_name, dataset_directory, batch_dice, stage, \
+ trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier)
+
+ if trainer_class is None:
+ raise RuntimeError("Could not find trainer class in meddec.model_training")
+
+ if network == "3d_cascade_fullres":
+ assert issubclass(trainer_class, (nnUNetTrainerCascadeFullRes, nnUNetTrainerV2CascadeFullRes)), \
+ "If running 3d_cascade_fullres then your " \
+ "trainer class must be derived from " \
+ "nnUNetTrainerCascadeFullRes"
+ else:
+ assert issubclass(trainer_class,
+ nnUNetTrainer), "network_trainer was found but is not derived from nnUNetTrainer"
+
+ trainer = trainer_class(plans_file, fold, local_rank=args.local_rank, output_folder=output_folder_name,
+ dataset_directory=dataset_directory, batch_dice=batch_dice, stage=stage,
+ unpack_data=decompress_data, deterministic=deterministic, fp16=not fp32,
+ distribute_batch_size=args.dbs)
+
+ if args.disable_saving:
+ trainer.save_latest_only = False # if false it will not store/overwrite _latest but separate files each
+ trainer.save_intermediate_checkpoints = False # whether or not to save checkpoint_latest
+ trainer.save_best_checkpoint = False # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
+ trainer.save_final_checkpoint = False # whether or not to save the final checkpoint
+
+ trainer.initialize(not validation_only)
+
+ if find_lr:
+ trainer.find_lr()
+ else:
+ if not validation_only:
+ if args.continue_training:
+ # -c was set, continue a previous training and ignore pretrained weights
+ trainer.load_latest_checkpoint()
+ elif (not args.continue_training) and (args.pretrained_weights is not None):
+ # we start a new training. If pretrained_weights are set, use them
+ load_pretrained_weights(trainer.network, args.pretrained_weights)
+ else:
+ # new training without pretrained weights, do nothing
+ pass
+
+ trainer.run_training()
+ else:
+ if valbest:
+ trainer.load_best_checkpoint(train=False)
+ else:
+ trainer.load_final_checkpoint(train=False)
+
+ trainer.network.eval()
+
+ # predict validation
+ trainer.validate(save_softmax=args.npz, validation_folder_name=val_folder,
+ run_postprocessing_on_folds=not disable_postprocessing_on_folds)
+
+ if network == '3d_lowres':
+ print("predicting segmentations for the next stage of the cascade")
+ predict_next_stage(trainer, join(dataset_directory, trainer.plans['data_identifier'] + "_stage%d" % 1))
+
+
+if __name__ == "__main__":
+ main()
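The --dbs flag only changes how the batch size from the plans file is mapped onto the GPUs. A small sketch of the two modes, matching the help text above (numbers are illustrative):

    def effective_batch_sizes(plans_batch_size, num_gpus, dbs):
        # dbs=True: the plans batch size is split across the GPUs (global size unchanged)
        # dbs=False: every GPU runs the full plans batch size (global size is GPUs * batch_size)
        per_gpu = plans_batch_size // num_gpus if dbs else plans_batch_size
        return per_gpu, per_gpu * num_gpus

    print(effective_batch_sizes(2, 4, dbs=False))  # (2, 8)
    print(effective_batch_sizes(8, 4, dbs=True))   # (2, 8)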
diff --git a/nnunet/run/run_training_DP.py b/nnunet/run/run_training_DP.py
new file mode 100644
index 0000000000000000000000000000000000000000..922afb5c70b6cb656ad781af61b4fe78d1bfab99
--- /dev/null
+++ b/nnunet/run/run_training_DP.py
@@ -0,0 +1,194 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import argparse
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.run.default_configuration import get_default_configuration
+from nnunet.paths import default_plans_identifier
+from nnunet.run.load_pretrained_weights import load_pretrained_weights
+from nnunet.training.cascade_stuff.predict_next_stage import predict_next_stage
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+from nnunet.training.network_training.nnUNetTrainerCascadeFullRes import nnUNetTrainerCascadeFullRes
+from nnunet.utilities.task_name_id_conversion import convert_id_to_task_name
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("network")
+ parser.add_argument("network_trainer")
+ parser.add_argument("task", help="can be task name or task id")
+ parser.add_argument("fold", help='0, 1, ..., 5 or \'all\'')
+ parser.add_argument("-val", "--validation_only", help="use this if you want to only run the validation",
+ action="store_true")
+ parser.add_argument("-c", "--continue_training", help="use this if you want to continue a training",
+ action="store_true")
+ parser.add_argument("-p", help="plans identifier. Only change this if you created a custom experiment planner",
+ default=default_plans_identifier, required=False)
+ parser.add_argument("--use_compressed_data", default=False, action="store_true",
+ help="If you set use_compressed_data, the training cases will not be decompressed. Reading compressed data "
+ "is much more CPU and RAM intensive and should only be used if you know what you are "
+ "doing", required=False)
+ parser.add_argument("--deterministic",
+ help="Makes training deterministic, but reduces training speed substantially. I (Fabian) think "
+ "this is not necessary. Deterministic training will make you overfit to some random seed. "
+ "Don't use that.",
+ required=False, default=False, action="store_true")
+ parser.add_argument("-gpus", help="number of gpus", required=True,type=int)
+ parser.add_argument("--dbs", required=False, default=False, action="store_true", help="distribute batch size. If "
+ "True then whatever "
+ "batch_size is in plans will "
+ "be distributed over DDP "
+ "models, if False then each "
+ "model will have batch_size "
+ "for a total of "
+ "GPUs*batch_size")
+ parser.add_argument("--npz", required=False, default=False, action="store_true", help="if set then nnUNet will "
+ "export npz files of "
+ "predicted segmentations "
+ "in the vlaidation as well. "
+ "This is needed to run the "
+ "ensembling step so unless "
+ "you are developing nnUNet "
+ "you should enable this")
+ parser.add_argument("--valbest", required=False, default=False, action="store_true", help="")
+ parser.add_argument("--find_lr", required=False, default=False, action="store_true", help="")
+ parser.add_argument("--fp32", required=False, default=False, action="store_true",
+ help="disable mixed precision training and run old school fp32")
+ parser.add_argument("--val_folder", required=False, default="validation_raw",
+ help="name of the validation folder. No need to use this for most people")
+ parser.add_argument("--disable_saving", required=False, action='store_true',
+ help="If set nnU-Net will not save any parameter files. Useful for development when you are "
+ "only interested in the results and want to save some disk space")
+ parser.add_argument("--disable_postprocessing_on_folds", required=False, action='store_true',
+ help="Running postprocessing on each fold only makes sense when developing with nnU-Net and "
+ "closely observing the model performance on specific configurations. You do not need it "
+ "when applying nnU-Net because the postprocessing for this will be determined only once "
+ "all five folds have been trained and nnUNet_find_best_configuration is called. Usually "
+ "running postprocessing on each fold is computationally cheap, but some users have "
+ "reported issues with very large images. If your images are large (>600x600x600 voxels) "
+ "you should consider setting this flag.")
+ # parser.add_argument("--interp_order", required=False, default=3, type=int,
+ # help="order of interpolation for segmentations. Testing purpose only. Hands off")
+ # parser.add_argument("--interp_order_z", required=False, default=0, type=int,
+ # help="order of interpolation along z if z is resampled separately. Testing purpose only. "
+ # "Hands off")
+ # parser.add_argument("--force_separate_z", required=False, default="None", type=str,
+ # help="force_separate_z resampling. Can be None, True or False. Testing purpose only. Hands off")
+ parser.add_argument('-pretrained_weights', type=str, required=False, default=None,
+ help='path to nnU-Net checkpoint file to be used as pretrained model (use .model '
+ 'file, for example model_final_checkpoint.model). Will only be used when actually training. '
+ 'Optional. Beta. Use with caution.')
+
+ args = parser.parse_args()
+
+ task = args.task
+ fold = args.fold
+ network = args.network
+ network_trainer = args.network_trainer
+ validation_only = args.validation_only
+ plans_identifier = args.p
+ disable_postprocessing_on_folds = args.disable_postprocessing_on_folds
+
+ use_compressed_data = args.use_compressed_data
+ decompress_data = not use_compressed_data
+
+ deterministic = args.deterministic
+ valbest = args.valbest
+ find_lr = args.find_lr
+ num_gpus = args.gpus
+ fp32 = args.fp32
+ val_folder = args.val_folder
+ # interp_order = args.interp_order
+ # interp_order_z = args.interp_order_z
+ # force_separate_z = args.force_separate_z
+
+ if not task.startswith("Task"):
+ task_id = int(task)
+ task = convert_id_to_task_name(task_id)
+
+ if fold == 'all':
+ pass
+ else:
+ fold = int(fold)
+
+ # if force_separate_z == "None":
+ # force_separate_z = None
+ # elif force_separate_z == "False":
+ # force_separate_z = False
+ # elif force_separate_z == "True":
+ # force_separate_z = True
+ # else:
+ # raise ValueError("force_separate_z must be None, True or False. Given: %s" % force_separate_z)
+
+ plans_file, output_folder_name, dataset_directory, batch_dice, stage, \
+ trainer_class = get_default_configuration(network, task, network_trainer, plans_identifier)
+
+ if trainer_class is None:
+ raise RuntimeError("Could not find trainer class")
+
+ if network == "3d_cascade_fullres":
+ assert issubclass(trainer_class, nnUNetTrainerCascadeFullRes), "If running 3d_cascade_fullres then your " \
+ "trainer class must be derived from " \
+ "nnUNetTrainerCascadeFullRes"
+ else:
+ assert issubclass(trainer_class, nnUNetTrainer), "network_trainer was found but is not derived from " \
+ "nnUNetTrainer"
+
+ trainer = trainer_class(plans_file, fold, output_folder=output_folder_name,
+ dataset_directory=dataset_directory, batch_dice=batch_dice, stage=stage,
+ unpack_data=decompress_data, deterministic=deterministic,
+ distribute_batch_size=args.dbs, num_gpus=num_gpus, fp16=not fp32)
+
+ if args.disable_saving:
+ trainer.save_latest_only = False # if false it will not store/overwrite _latest but separate files each
+ trainer.save_intermediate_checkpoints = False # whether or not to save checkpoint_latest
+ trainer.save_best_checkpoint = False # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
+ trainer.save_final_checkpoint = False # whether or not to save the final checkpoint
+
+ trainer.initialize(not validation_only)
+
+ if find_lr:
+ trainer.find_lr()
+ else:
+ if not validation_only:
+ if args.continue_training:
+ # -c was set, continue a previous training and ignore pretrained weights
+ trainer.load_latest_checkpoint()
+ elif (not args.continue_training) and (args.pretrained_weights is not None):
+ # we start a new training. If pretrained_weights are set, use them
+ load_pretrained_weights(trainer.network, args.pretrained_weights)
+ else:
+ # new training without pretrained weights, do nothing
+ pass
+ trainer.run_training()
+ else:
+ if valbest:
+ trainer.load_best_checkpoint(train=False)
+ else:
+ trainer.load_final_checkpoint(train=False)
+
+ trainer.network.eval()
+
+ # predict validation
+ trainer.validate(save_softmax=args.npz, validation_folder_name=val_folder,
+ run_postprocessing_on_folds=not disable_postprocessing_on_folds)
+
+ if network == '3d_lowres':
+ print("predicting segmentations for the next stage of the cascade")
+ predict_next_stage(trainer, join(dataset_directory, trainer.plans['data_identifier'] + "_stage%d" % 1))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/nnunet/training/__init__.py b/nnunet/training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b8078b9dddddf22182fec2555d8d118ea72622
--- /dev/null
+++ b/nnunet/training/__init__.py
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *
\ No newline at end of file
diff --git a/nnunet/training/cascade_stuff/__init__.py b/nnunet/training/cascade_stuff/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b8078b9dddddf22182fec2555d8d118ea72622
--- /dev/null
+++ b/nnunet/training/cascade_stuff/__init__.py
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *
\ No newline at end of file
diff --git a/nnunet/training/cascade_stuff/predict_next_stage.py b/nnunet/training/cascade_stuff/predict_next_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c760cd048d5f0b003d7fdd86b457307ef608c24
--- /dev/null
+++ b/nnunet/training/cascade_stuff/predict_next_stage.py
@@ -0,0 +1,135 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from copy import deepcopy
+
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+import argparse
+from nnunet.preprocessing.preprocessing import resample_data_or_seg
+import nnunet
+from nnunet.run.default_configuration import get_default_configuration
+from multiprocessing import Pool
+
+from nnunet.training.model_restore import recursive_find_python_class
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+
+
+def resample_and_save(predicted, target_shape, output_file, force_separate_z=False,
+ interpolation_order=1, interpolation_order_z=0):
+ if isinstance(predicted, str):
+ assert isfile(predicted), "If isinstance(predicted, str) then " \
+ "isfile(predicted) must be True"
+ del_file = deepcopy(predicted)
+ predicted = np.load(predicted)
+ os.remove(del_file)
+
+ predicted_new_shape = resample_data_or_seg(predicted, target_shape, False, order=interpolation_order,
+ do_separate_z=force_separate_z, order_z=interpolation_order_z)
+ seg_new_shape = predicted_new_shape.argmax(0)
+ np.savez_compressed(output_file, data=seg_new_shape.astype(np.uint8))
+
+
+def predict_next_stage(trainer, stage_to_be_predicted_folder):
+ output_folder = join(pardir(trainer.output_folder), "pred_next_stage")
+ maybe_mkdir_p(output_folder)
+
+ if 'segmentation_export_params' in trainer.plans.keys():
+ force_separate_z = trainer.plans['segmentation_export_params']['force_separate_z']
+ interpolation_order = trainer.plans['segmentation_export_params']['interpolation_order']
+ interpolation_order_z = trainer.plans['segmentation_export_params']['interpolation_order_z']
+ else:
+ force_separate_z = None
+ interpolation_order = 1
+ interpolation_order_z = 0
+
+ export_pool = Pool(2)
+ results = []
+
+ for pat in trainer.dataset_val.keys():
+ print(pat)
+ data_file = trainer.dataset_val[pat]['data_file']
+ data_preprocessed = np.load(data_file)['data'][:-1]
+
+ predicted_probabilities = trainer.predict_preprocessed_data_return_seg_and_softmax(
+ data_preprocessed, do_mirroring=trainer.data_aug_params["do_mirror"],
+ mirror_axes=trainer.data_aug_params['mirror_axes'], mixed_precision=trainer.fp16)[1]
+
+ data_file_nofolder = data_file.split("/")[-1]
+ data_file_nextstage = join(stage_to_be_predicted_folder, data_file_nofolder)
+ data_nextstage = np.load(data_file_nextstage)['data']
+ target_shp = data_nextstage.shape[1:]
+ output_file = join(output_folder, data_file_nextstage.split("/")[-1][:-4] + "_segFromPrevStage.npz")
+
+ if np.prod(predicted_probabilities.shape) > (2e9 / 4 * 0.85): # float32 needs 4 bytes; stay below the ~2e9 byte limit, *0.85 just to be safe
+ np.save(output_file[:-4] + ".npy", predicted_probabilities)
+ predicted_probabilities = output_file[:-4] + ".npy"
+
+ results.append(export_pool.starmap_async(resample_and_save, [(predicted_probabilities, target_shp, output_file,
+ force_separate_z, interpolation_order,
+ interpolation_order_z)]))
+
+ _ = [i.get() for i in results]
+ export_pool.close()
+ export_pool.join()
+
+
+if __name__ == "__main__":
+ """
+ RUNNING THIS SCRIPT MANUALLY IS USUALLY NOT NECESSARY. USE THE run_training.py FILE!
+
+ This script is intended for predicting all the low resolution predictions of 3d_lowres for the next stage of the
+ cascade. It needs to run once for each fold so that the segmentation is only generated for the validation set
+ and not on the data the network was trained on. Run it with
+ python predict_next_stage TRAINERCLASS TASK FOLD"""
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("network_trainer")
+ parser.add_argument("task")
+ parser.add_argument("fold", type=int)
+
+ args = parser.parse_args()
+
+ trainerclass = args.network_trainer
+ task = args.task
+ fold = args.fold
+
+ plans_file, output_folder_name, dataset_directory, batch_dice, stage, trainer_class = \
+ get_default_configuration("3d_lowres", task, trainerclass)
+
+ if trainer_class is None:
+ raise RuntimeError("Could not find trainer class in nnunet.training.network_training")
+ else:
+ assert issubclass(trainer_class,
+ nnUNetTrainer), "network_trainer was found but is not derived from nnUNetTrainer"
+
+ trainer = trainer_class(plans_file, fold, output_folder=output_folder_name,
+ dataset_directory=dataset_directory, batch_dice=batch_dice, stage=stage)
+
+ trainer.initialize(False)
+ trainer.load_dataset()
+ trainer.do_split()
+ trainer.load_best_checkpoint(train=False)
+
+ stage_to_be_predicted_folder = join(dataset_directory, trainer.plans['data_identifier'] + "_stage%d" % 1)
+ output_folder = join(pardir(trainer.output_folder), "pred_next_stage")
+ maybe_mkdir_p(output_folder)
+
+ predict_next_stage(trainer, stage_to_be_predicted_folder)
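The export step above converts softmax probabilities into a compressed uint8 label map for the next stage's grid. Its core, in isolation (shapes are examples):

    import numpy as np

    softmax = np.random.rand(3, 80, 160, 160).astype(np.float32)  # (classes, x, y, z)
    seg = softmax.argmax(0).astype(np.uint8)                      # hard label per voxel
    np.savez_compressed("case_segFromPrevStage.npz", data=seg)    # same format as resample_and_save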
diff --git a/nnunet/training/data_augmentation/__init__.py b/nnunet/training/data_augmentation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b8078b9dddddf22182fec2555d8d118ea72622
--- /dev/null
+++ b/nnunet/training/data_augmentation/__init__.py
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *
\ No newline at end of file
diff --git a/nnunet/training/data_augmentation/custom_transforms.py b/nnunet/training/data_augmentation/custom_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..b94e64a72743fe51ce78816ef8fb0bc0b6433168
--- /dev/null
+++ b/nnunet/training/data_augmentation/custom_transforms.py
@@ -0,0 +1,123 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+from batchgenerators.transforms.abstract_transforms import AbstractTransform
+
+
+class RemoveKeyTransform(AbstractTransform):
+ def __init__(self, key_to_remove):
+ self.key_to_remove = key_to_remove
+
+ def __call__(self, **data_dict):
+ _ = data_dict.pop(self.key_to_remove, None)
+ return data_dict
+
+
+class MaskTransform(AbstractTransform):
+ def __init__(self, dct_for_where_it_was_used, mask_idx_in_seg=1, set_outside_to=0, data_key="data", seg_key="seg"):
+ """
+ data[mask < 0] = 0
+ Sets everything outside the mask to 0. CAREFUL! outside is defined as < 0, not =0 (in the Mask)!!!
+
+ :param dct_for_where_it_was_used:
+ :param mask_idx_in_seg:
+ :param set_outside_to:
+ :param data_key:
+ :param seg_key:
+ """
+ self.dct_for_where_it_was_used = dct_for_where_it_was_used
+ self.seg_key = seg_key
+ self.data_key = data_key
+ self.set_outside_to = set_outside_to
+ self.mask_idx_in_seg = mask_idx_in_seg
+
+ def __call__(self, **data_dict):
+ seg = data_dict.get(self.seg_key)
+ if seg is None or seg.shape[1] <= self.mask_idx_in_seg:
+ raise RuntimeError("mask not found: seg is missing or seg[:, mask_idx_in_seg] does not exist")
+ data = data_dict.get(self.data_key)
+ for b in range(data.shape[0]):
+ mask = seg[b, self.mask_idx_in_seg]
+ for c in range(data.shape[1]):
+ if self.dct_for_where_it_was_used[c]:
+ data[b, c][mask < 0] = self.set_outside_to
+ data_dict[self.data_key] = data
+ return data_dict
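+
+# Hedged usage sketch for MaskTransform (added for illustration, not part of the original file):
+# dct_for_where_it_was_used is indexed per data channel; channels with a truthy entry get masked,
+# and everything where seg[:, mask_idx_in_seg] < 0 counts as outside.
+#
+# data = np.random.randn(1, 1, 4, 4).astype(np.float32) # b, c, x, y
+# seg = np.ones((1, 1, 4, 4), dtype=np.float32)
+# seg[0, 0, :2] = -1
+# out = MaskTransform({0: True}, mask_idx_in_seg=0, set_outside_to=0)(data=data, seg=seg)
+# assert np.all(out['data'][0, 0, :2] == 0)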
+
+
+def convert_3d_to_2d_generator(data_dict):
+ shp = data_dict['data'].shape
+ data_dict['data'] = data_dict['data'].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4]))
+ data_dict['orig_shape_data'] = shp
+ shp = data_dict['seg'].shape
+ data_dict['seg'] = data_dict['seg'].reshape((shp[0], shp[1] * shp[2], shp[3], shp[4]))
+ data_dict['orig_shape_seg'] = shp
+ return data_dict
+
+
+def convert_2d_to_3d_generator(data_dict):
+ shp = data_dict['orig_shape_data']
+ current_shape = data_dict['data'].shape
+ data_dict['data'] = data_dict['data'].reshape((shp[0], shp[1], shp[2], current_shape[-2], current_shape[-1]))
+ shp = data_dict['orig_shape_seg']
+ current_shape_seg = data_dict['seg'].shape
+ data_dict['seg'] = data_dict['seg'].reshape((shp[0], shp[1], shp[2], current_shape_seg[-2], current_shape_seg[-1]))
+ return data_dict
+
+
+class Convert3DTo2DTransform(AbstractTransform):
+ def __init__(self):
+ pass
+
+ def __call__(self, **data_dict):
+ return convert_3d_to_2d_generator(data_dict)
+
+
+class Convert2DTo3DTransform(AbstractTransform):
+ def __init__(self):
+ pass
+
+ def __call__(self, **data_dict):
+ return convert_2d_to_3d_generator(data_dict)
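+
+# Hedged round-trip sketch (added for illustration, not part of the original file):
+# the dummy-2D trick above folds the z axis into the channel axis so 2D transforms
+# can process 3D data, then unfolds it again. Only shapes change, values do not.
+#
+# d = {'data': np.zeros((2, 1, 8, 16, 16)), 'seg': np.zeros((2, 1, 8, 16, 16))}
+# d = Convert3DTo2DTransform()(**d)
+# assert d['data'].shape == (2, 8, 16, 16) # c * z folded together
+# d = Convert2DTo3DTransform()(**d)
+# assert d['data'].shape == (2, 1, 8, 16, 16) # original shape restored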
+
+
+class ConvertSegmentationToRegionsTransform(AbstractTransform):
+ def __init__(self, regions: dict, seg_key: str = "seg", output_key: str = "seg", seg_channel: int = 0):
+ """
+ regions are tuple of tuples where each inner tuple holds the class indices that are merged into one region, example:
+ regions= ((1, 2), (2, )) will result in 2 regions: one covering the region of labels 1&2 and the other just 2
+ :param regions:
+ :param seg_key:
+ :param output_key:
+ """
+ self.seg_channel = seg_channel
+ self.output_key = output_key
+ self.seg_key = seg_key
+ self.regions = regions
+
+ def __call__(self, **data_dict):
+ seg = data_dict.get(self.seg_key)
+ num_regions = len(self.regions)
+ if seg is not None:
+ seg_shp = seg.shape
+ output_shape = list(seg_shp)
+ output_shape[1] = num_regions
+ region_output = np.zeros(output_shape, dtype=seg.dtype)
+ for b in range(seg_shp[0]):
+ for r, k in enumerate(self.regions.keys()):
+ for l in self.regions[k]:
+ region_output[b, r][seg[b, self.seg_channel] == l] = 1
+ data_dict[self.output_key] = region_output
+ return data_dict
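+
+
+if __name__ == "__main__":
+ # Hedged demo (added for illustration, not part of the original file): regions must
+ # be a dict mapping each region name to the labels merged into that region; output
+ # channel r is the union of the labels of region r.
+ _seg = np.zeros((1, 1, 4, 4), dtype=np.float32)
+ _seg[0, 0, 0, 0] = 1
+ _seg[0, 0, 1, 1] = 2
+ _out = ConvertSegmentationToRegionsTransform({'whole': (1, 2), 'core': (2,)})(seg=_seg)
+ assert _out['seg'].shape == (1, 2, 4, 4) # one channel per region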
diff --git a/nnunet/training/data_augmentation/data_augmentation_insaneDA.py b/nnunet/training/data_augmentation/data_augmentation_insaneDA.py
new file mode 100644
index 0000000000000000000000000000000000000000..e463167a51ac01e248a31e4cca0ae54ea206ee6c
--- /dev/null
+++ b/nnunet/training/data_augmentation/data_augmentation_insaneDA.py
@@ -0,0 +1,183 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
+from batchgenerators.transforms.abstract_transforms import Compose
+from batchgenerators.transforms.channel_selection_transforms import DataChannelSelectionTransform, \
+ SegChannelSelectionTransform
+from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
+ ContrastAugmentationTransform, BrightnessTransform, GammaTransform
+from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
+from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
+from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
+from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
+from nnunet.training.data_augmentation.custom_transforms import Convert3DTo2DTransform, Convert2DTo3DTransform, \
+ MaskTransform, ConvertSegmentationToRegionsTransform
+from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params
+from nnunet.training.data_augmentation.downsampling import DownsampleSegForDSTransform3, DownsampleSegForDSTransform2
+from nnunet.training.data_augmentation.pyramid_augmentations import MoveSegAsOneHotToData, \
+ ApplyRandomBinaryOperatorTransform, \
+ RemoveRandomConnectedComponentFromOneHotEncodingTransform
+
+try:
+ from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
+except ImportError as ie:
+ NonDetMultiThreadedAugmenter = None
+
+
+def get_insaneDA_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
+ border_val_seg=-1,
+ seeds_train=None, seeds_val=None, order_seg=1, order_data=3, deep_supervision_scales=None,
+ soft_ds=False,
+ classes=None, pin_memory=True, regions=None):
+ assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"
+
+ tr_transforms = []
+
+ if params.get("selected_data_channels") is not None:
+ tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
+
+ if params.get("selected_seg_channels") is not None:
+ tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
+
+ # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
+ if params.get("dummy_2D") is not None and params.get("dummy_2D"):
+ ignore_axes = (0,)
+ tr_transforms.append(Convert3DTo2DTransform())
+ patch_size_spatial = patch_size[1:]
+ else:
+ patch_size_spatial = patch_size
+ ignore_axes = None
+
+ tr_transforms.append(SpatialTransform(
+ patch_size_spatial, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
+ alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"),
+ do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
+ angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
+ border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=order_data,
+ border_mode_seg="constant", border_cval_seg=border_val_seg,
+ order_seg=order_seg, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
+ p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
+ independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis"),
+ p_independent_scale_per_axis=params.get("p_independent_scale_per_axis")
+ ))
+
+ if params.get("dummy_2D"):
+ tr_transforms.append(Convert2DTo3DTransform())
+
+ # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
+ # channel gets in the way
+ tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
+ tr_transforms.append(GaussianBlurTransform((0.5, 1.5), different_sigma_per_channel=True, p_per_sample=0.2,
+ p_per_channel=0.5))
+ tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3), p_per_sample=0.15))
+ tr_transforms.append(ContrastAugmentationTransform(contrast_range=(0.65, 1.5), p_per_sample=0.15))
+ tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
+ p_per_channel=0.5,
+ order_downsample=0, order_upsample=3, p_per_sample=0.25,
+ ignore_axes=ignore_axes))
+ tr_transforms.append(
+ GammaTransform(params.get("gamma_range"), True, True, retain_stats=params.get("gamma_retain_stats"),
+ p_per_sample=0.15)) # inverted gamma
+
+ if params.get("do_additive_brightness"):
+ tr_transforms.append(BrightnessTransform(params.get("additive_brightness_mu"),
+ params.get("additive_brightness_sigma"),
+ True, p_per_sample=params.get("additive_brightness_p_per_sample"),
+ p_per_channel=params.get("additive_brightness_p_per_channel")))
+
+ if params.get("do_gamma"):
+ tr_transforms.append(
+ GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
+ p_per_sample=params["p_gamma"]))
+
+ if params.get("do_mirror") or params.get("mirror"):
+ tr_transforms.append(MirrorTransform(params.get("mirror_axes")))
+
+ if params.get("mask_was_used_for_normalization") is not None:
+ mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
+ tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))
+
+ tr_transforms.append(RemoveLabelTransform(-1, 0))
+
+ if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
+ tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
+ if params.get("cascade_do_cascade_augmentations") and not None and params.get(
+ "cascade_do_cascade_augmentations"):
+ if params.get("cascade_random_binary_transform_p") > 0:
+ tr_transforms.append(ApplyRandomBinaryOperatorTransform(
+ channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
+ p_per_sample=params.get("cascade_random_binary_transform_p"),
+ key="data",
+ strel_size=params.get("cascade_random_binary_transform_size")))
+ if params.get("cascade_remove_conn_comp_p") > 0:
+ tr_transforms.append(
+ RemoveRandomConnectedComponentFromOneHotEncodingTransform(
+ channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
+ key="data",
+ p_per_sample=params.get("cascade_remove_conn_comp_p"),
+ fill_with_other_class_p=params.get("cascade_remove_conn_comp_max_size_percent_threshold"),
+ dont_do_if_covers_more_than_X_percent=params.get(
+ "cascade_remove_conn_comp_fill_with_other_class_p")))
+
+ tr_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ if soft_ds:
+ assert classes is not None
+ tr_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
+ else:
+ tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+
+ tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ tr_transforms = Compose(tr_transforms)
+
+ batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
+ params.get("num_cached_per_thread"),
+ seeds=seeds_train, pin_memory=pin_memory)
+
+ val_transforms = []
+ val_transforms.append(RemoveLabelTransform(-1, 0))
+ if params.get("selected_data_channels") is not None:
+ val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
+ if params.get("selected_seg_channels") is not None:
+ val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
+
+ if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
+ val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
+
+ val_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ if soft_ds:
+ assert classes is not None
+ val_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
+ else:
+ val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+
+ val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ val_transforms = Compose(val_transforms)
+
+ batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
+ params.get("num_cached_per_thread"),
+ seeds=seeds_val, pin_memory=pin_memory)
+ return batchgenerator_train, batchgenerator_val
\ No newline at end of file
diff --git a/nnunet/training/data_augmentation/data_augmentation_insaneDA2.py b/nnunet/training/data_augmentation/data_augmentation_insaneDA2.py
new file mode 100644
index 0000000000000000000000000000000000000000..69a06b83d500299697d886fb215dc2dbb5d06d10
--- /dev/null
+++ b/nnunet/training/data_augmentation/data_augmentation_insaneDA2.py
@@ -0,0 +1,188 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
+from batchgenerators.transforms.abstract_transforms import Compose
+from batchgenerators.transforms.channel_selection_transforms import DataChannelSelectionTransform, \
+ SegChannelSelectionTransform
+from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
+ ContrastAugmentationTransform, BrightnessTransform
+from batchgenerators.transforms.color_transforms import GammaTransform
+from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
+from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
+from batchgenerators.transforms.spatial_transforms import MirrorTransform
+from batchgenerators.transforms.spatial_transforms import SpatialTransform_2
+from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
+
+from nnunet.training.data_augmentation.custom_transforms import Convert3DTo2DTransform, Convert2DTo3DTransform, \
+ MaskTransform, ConvertSegmentationToRegionsTransform
+from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params
+from nnunet.training.data_augmentation.downsampling import DownsampleSegForDSTransform3, DownsampleSegForDSTransform2
+from nnunet.training.data_augmentation.pyramid_augmentations import MoveSegAsOneHotToData, \
+ ApplyRandomBinaryOperatorTransform, \
+ RemoveRandomConnectedComponentFromOneHotEncodingTransform
+
+try:
+ from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
+except ImportError as ie:
+ NonDetMultiThreadedAugmenter = None
+
+
+def get_insaneDA_augmentation2(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
+ border_val_seg=-1,
+ seeds_train=None, seeds_val=None, order_seg=1, order_data=3, deep_supervision_scales=None,
+ soft_ds=False,
+ classes=None, pin_memory=True, regions=None):
+ assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"
+
+ tr_transforms = []
+
+ if params.get("selected_data_channels") is not None:
+ tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
+
+ if params.get("selected_seg_channels") is not None:
+ tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
+
+ # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
+ if params.get("dummy_2D") is not None and params.get("dummy_2D"):
+ ignore_axes = (0,)
+ tr_transforms.append(Convert3DTo2DTransform())
+ patch_size_spatial = patch_size[1:]
+ else:
+ patch_size_spatial = patch_size
+ ignore_axes = None
+
+ tr_transforms.append(SpatialTransform_2(
+ patch_size_spatial, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
+ deformation_scale=params.get("eldef_deformation_scale"),
+ do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
+ angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
+ border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=order_data,
+ border_mode_seg="constant", border_cval_seg=border_val_seg,
+ order_seg=order_seg, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
+ p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
+ independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis"),
+ p_independent_scale_per_axis=params.get("p_independent_scale_per_axis")
+ ))
+
+ if params.get("dummy_2D"):
+ tr_transforms.append(Convert2DTo3DTransform())
+
+ # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
+ # channel gets in the way
+ tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.15))
+ tr_transforms.append(GaussianBlurTransform((0.5, 1.5), different_sigma_per_channel=True, p_per_sample=0.2,
+ p_per_channel=0.5))
+ tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.70, 1.3), p_per_sample=0.15))
+ tr_transforms.append(ContrastAugmentationTransform(contrast_range=(0.65, 1.5), p_per_sample=0.15))
+ tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
+ p_per_channel=0.5,
+ order_downsample=0, order_upsample=3, p_per_sample=0.25,
+ ignore_axes=ignore_axes))
+ tr_transforms.append(
+ GammaTransform(params.get("gamma_range"), True, True, retain_stats=params.get("gamma_retain_stats"),
+ p_per_sample=0.15)) # inverted gamma
+
+ if params.get("do_additive_brightness"):
+ tr_transforms.append(BrightnessTransform(params.get("additive_brightness_mu"),
+ params.get("additive_brightness_sigma"),
+ True, p_per_sample=params.get("additive_brightness_p_per_sample"),
+ p_per_channel=params.get("additive_brightness_p_per_channel")))
+
+ if params.get("do_gamma"):
+ tr_transforms.append(
+ GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
+ p_per_sample=params["p_gamma"]))
+
+ if params.get("do_mirror") or params.get("mirror"):
+ tr_transforms.append(MirrorTransform(params.get("mirror_axes")))
+
+ if params.get("mask_was_used_for_normalization") is not None:
+ mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
+ tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))
+
+ tr_transforms.append(RemoveLabelTransform(-1, 0))
+
+ if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
+ tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
+ if params.get("cascade_do_cascade_augmentations") and not None and params.get(
+ "cascade_do_cascade_augmentations"):
+ if params.get("cascade_random_binary_transform_p") > 0:
+ tr_transforms.append(ApplyRandomBinaryOperatorTransform(
+ channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
+ p_per_sample=params.get("cascade_random_binary_transform_p"),
+ key="data",
+ strel_size=params.get("cascade_random_binary_transform_size")))
+ if params.get("cascade_remove_conn_comp_p") > 0:
+ tr_transforms.append(
+ RemoveRandomConnectedComponentFromOneHotEncodingTransform(
+ channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
+ key="data",
+ p_per_sample=params.get("cascade_remove_conn_comp_p"),
+ fill_with_other_class_p=params.get("cascade_remove_conn_comp_max_size_percent_threshold"),
+ dont_do_if_covers_more_than_X_percent=params.get(
+ "cascade_remove_conn_comp_fill_with_other_class_p")))
+
+ tr_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ if soft_ds:
+ assert classes is not None
+ tr_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
+ else:
+ tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+
+ tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ tr_transforms = Compose(tr_transforms)
+
+ batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
+ params.get("num_cached_per_thread"),
+ seeds=seeds_train, pin_memory=pin_memory)
+ #batchgenerator_train = SingleThreadedAugmenter(dataloader_train, tr_transforms)
+
+ val_transforms = []
+ val_transforms.append(RemoveLabelTransform(-1, 0))
+ if params.get("selected_data_channels") is not None:
+ val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
+ if params.get("selected_seg_channels") is not None:
+ val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
+
+ if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
+ val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
+
+ val_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ if soft_ds:
+ assert classes is not None
+ val_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
+ else:
+ val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+
+ val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ val_transforms = Compose(val_transforms)
+
+ batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
+ params.get("num_cached_per_thread"),
+ seeds=seeds_val, pin_memory=pin_memory)
+ return batchgenerator_train, batchgenerator_val
+
diff --git a/nnunet/training/data_augmentation/data_augmentation_moreDA.py b/nnunet/training/data_augmentation/data_augmentation_moreDA.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8a7a6cb2acb06cf3411d1cff9310f2e3a9bc838
--- /dev/null
+++ b/nnunet/training/data_augmentation/data_augmentation_moreDA.py
@@ -0,0 +1,210 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
+from batchgenerators.transforms.abstract_transforms import Compose
+from batchgenerators.transforms.channel_selection_transforms import DataChannelSelectionTransform, \
+ SegChannelSelectionTransform
+from batchgenerators.transforms.color_transforms import BrightnessMultiplicativeTransform, \
+ ContrastAugmentationTransform, BrightnessTransform
+from batchgenerators.transforms.color_transforms import GammaTransform
+from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
+from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
+from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
+from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
+
+from nnunet.training.data_augmentation.custom_transforms import Convert3DTo2DTransform, Convert2DTo3DTransform, \
+ MaskTransform, ConvertSegmentationToRegionsTransform
+from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params
+from nnunet.training.data_augmentation.downsampling import DownsampleSegForDSTransform3, DownsampleSegForDSTransform2
+from nnunet.training.data_augmentation.pyramid_augmentations import MoveSegAsOneHotToData, \
+ ApplyRandomBinaryOperatorTransform, \
+ RemoveRandomConnectedComponentFromOneHotEncodingTransform
+
+try:
+ from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
+except ImportError as ie:
+ NonDetMultiThreadedAugmenter = None
+
+
+def get_moreDA_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
+ border_val_seg=-1,
+ seeds_train=None, seeds_val=None, order_seg=1, order_data=3, deep_supervision_scales=None,
+ soft_ds=False,
+ classes=None, pin_memory=True, regions=None,
+ use_nondetMultiThreadedAugmenter: bool = False):
+ assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"
+
+ tr_transforms = []
+
+ if params.get("selected_data_channels") is not None:
+ tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
+
+ if params.get("selected_seg_channels") is not None:
+ tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
+
+ # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
+ if params.get("dummy_2D") is not None and params.get("dummy_2D"):
+ ignore_axes = (0,)
+ tr_transforms.append(Convert3DTo2DTransform())
+ patch_size_spatial = patch_size[1:]
+ else:
+ patch_size_spatial = patch_size
+ ignore_axes = None
+
+ tr_transforms.append(SpatialTransform(
+ patch_size_spatial, patch_center_dist_from_border=None,
+ do_elastic_deform=params.get("do_elastic"), alpha=params.get("elastic_deform_alpha"),
+ sigma=params.get("elastic_deform_sigma"),
+ do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
+ angle_z=params.get("rotation_z"), p_rot_per_axis=params.get("rotation_p_per_axis"),
+ do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
+ border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=order_data,
+ border_mode_seg="constant", border_cval_seg=border_val_seg,
+ order_seg=order_seg, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
+ p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
+ independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
+ ))
+
+ if params.get("dummy_2D"):
+ tr_transforms.append(Convert2DTo3DTransform())
+
+ # we need to put the color augmentations after the dummy 2d part (if applicable). Otherwise the overloaded color
+ # channel gets in the way
+ tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
+ tr_transforms.append(GaussianBlurTransform((0.5, 1.), different_sigma_per_channel=True, p_per_sample=0.2,
+ p_per_channel=0.5))
+ tr_transforms.append(BrightnessMultiplicativeTransform(multiplier_range=(0.75, 1.25), p_per_sample=0.15))
+
+ if params.get("do_additive_brightness"):
+ tr_transforms.append(BrightnessTransform(params.get("additive_brightness_mu"),
+ params.get("additive_brightness_sigma"),
+ True, p_per_sample=params.get("additive_brightness_p_per_sample"),
+ p_per_channel=params.get("additive_brightness_p_per_channel")))
+
+ tr_transforms.append(ContrastAugmentationTransform(p_per_sample=0.15))
+ tr_transforms.append(SimulateLowResolutionTransform(zoom_range=(0.5, 1), per_channel=True,
+ p_per_channel=0.5,
+ order_downsample=0, order_upsample=3, p_per_sample=0.25,
+ ignore_axes=ignore_axes))
+ tr_transforms.append(
+ GammaTransform(params.get("gamma_range"), True, True, retain_stats=params.get("gamma_retain_stats"),
+ p_per_sample=0.1)) # inverted gamma
+
+ if params.get("do_gamma"):
+ tr_transforms.append(
+ GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
+ p_per_sample=params["p_gamma"]))
+
+ if params.get("do_mirror") or params.get("mirror"):
+ tr_transforms.append(MirrorTransform(params.get("mirror_axes")))
+
+ if params.get("mask_was_used_for_normalization") is not None:
+ mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
+ tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))
+
+ tr_transforms.append(RemoveLabelTransform(-1, 0))
+
+ if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
+ tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
+ if params.get("cascade_do_cascade_augmentations") is not None and params.get(
+ "cascade_do_cascade_augmentations"):
+ if params.get("cascade_random_binary_transform_p") > 0:
+ tr_transforms.append(ApplyRandomBinaryOperatorTransform(
+ channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
+ p_per_sample=params.get("cascade_random_binary_transform_p"),
+ key="data",
+ strel_size=params.get("cascade_random_binary_transform_size"),
+ p_per_label=params.get("cascade_random_binary_transform_p_per_label")))
+ if params.get("cascade_remove_conn_comp_p") > 0:
+ tr_transforms.append(
+ RemoveRandomConnectedComponentFromOneHotEncodingTransform(
+ channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
+ key="data",
+ p_per_sample=params.get("cascade_remove_conn_comp_p"),
+ fill_with_other_class_p=params.get("cascade_remove_conn_comp_max_size_percent_threshold"),
+ dont_do_if_covers_more_than_X_percent=params.get(
+ "cascade_remove_conn_comp_fill_with_other_class_p")))
+
+ tr_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ if soft_ds:
+ assert classes is not None
+ tr_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
+ else:
+ tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+
+ tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ tr_transforms = Compose(tr_transforms)
+
+ if use_nondetMultiThreadedAugmenter:
+ if NonDetMultiThreadedAugmenter is None:
+ raise RuntimeError('NonDetMultiThreadedAugmenter is not yet available')
+ batchgenerator_train = NonDetMultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
+ params.get("num_cached_per_thread"), seeds=seeds_train,
+ pin_memory=pin_memory)
+ else:
+ batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
+ params.get("num_cached_per_thread"),
+ seeds=seeds_train, pin_memory=pin_memory)
+ # batchgenerator_train = SingleThreadedAugmenter(dataloader_train, tr_transforms)
+ # import IPython;IPython.embed()
+
+ val_transforms = []
+ val_transforms.append(RemoveLabelTransform(-1, 0))
+ if params.get("selected_data_channels") is not None:
+ val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
+ if params.get("selected_seg_channels") is not None:
+ val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
+
+ if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
+ val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
+
+ val_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ if soft_ds:
+ assert classes is not None
+ val_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
+ else:
+ val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+
+ val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ val_transforms = Compose(val_transforms)
+
+ if use_nondetMultiThreadedAugmenter:
+ if NonDetMultiThreadedAugmenter is None:
+ raise RuntimeError('NonDetMultiThreadedAugmenter is not yet available')
+ batchgenerator_val = NonDetMultiThreadedAugmenter(dataloader_val, val_transforms,
+ max(params.get('num_threads') // 2, 1),
+ params.get("num_cached_per_thread"),
+ seeds=seeds_val, pin_memory=pin_memory)
+ else:
+ batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms,
+ max(params.get('num_threads') // 2, 1),
+ params.get("num_cached_per_thread"),
+ seeds=seeds_val, pin_memory=pin_memory)
+ # batchgenerator_val = SingleThreadedAugmenter(dataloader_val, val_transforms)
+
+ return batchgenerator_train, batchgenerator_val
+
diff --git a/nnunet/training/data_augmentation/data_augmentation_noDA.py b/nnunet/training/data_augmentation/data_augmentation_noDA.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5fe2fc6db782b5f436a86b38c9d5376d02d8973
--- /dev/null
+++ b/nnunet/training/data_augmentation/data_augmentation_noDA.py
@@ -0,0 +1,98 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
+from batchgenerators.transforms.abstract_transforms import Compose
+from batchgenerators.transforms.channel_selection_transforms import DataChannelSelectionTransform, \
+ SegChannelSelectionTransform
+from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
+
+from nnunet.training.data_augmentation.custom_transforms import ConvertSegmentationToRegionsTransform
+from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params
+from nnunet.training.data_augmentation.downsampling import DownsampleSegForDSTransform3, DownsampleSegForDSTransform2
+
+try:
+ from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
+except ImportError as ie:
+ NonDetMultiThreadedAugmenter = None
+
+
+def get_no_augmentation(dataloader_train, dataloader_val, params=default_3D_augmentation_params,
+ deep_supervision_scales=None, soft_ds=False,
+ classes=None, pin_memory=True, regions=None):
+ """
+ use this instead of get_default_augmentation (drop in replacement) to turn off all data augmentation
+ """
+ tr_transforms = []
+
+ if params.get("selected_data_channels") is not None:
+ tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
+
+ if params.get("selected_seg_channels") is not None:
+ tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
+
+ tr_transforms.append(RemoveLabelTransform(-1, 0))
+
+ tr_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ if soft_ds:
+ assert classes is not None
+ tr_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
+ else:
+ tr_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+
+ tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+
+ tr_transforms = Compose(tr_transforms)
+
+ batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
+ params.get("num_cached_per_thread"),
+ seeds=range(params.get('num_threads')), pin_memory=pin_memory)
+ batchgenerator_train.restart()
+
+ val_transforms = []
+ val_transforms.append(RemoveLabelTransform(-1, 0))
+ if params.get("selected_data_channels") is not None:
+ val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
+ if params.get("selected_seg_channels") is not None:
+ val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
+
+ val_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
+
+ if deep_supervision_scales is not None:
+ if soft_ds:
+ assert classes is not None
+ val_transforms.append(DownsampleSegForDSTransform3(deep_supervision_scales, 'target', 'target', classes))
+ else:
+ val_transforms.append(DownsampleSegForDSTransform2(deep_supervision_scales, 0, input_key='target',
+ output_key='target'))
+
+ val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ val_transforms = Compose(val_transforms)
+
+ batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
+ params.get("num_cached_per_thread"),
+ seeds=range(max(params.get('num_threads') // 2, 1)),
+ pin_memory=pin_memory)
+ batchgenerator_val.restart()
+ return batchgenerator_train, batchgenerator_val
+
diff --git a/nnunet/training/data_augmentation/default_data_augmentation.py b/nnunet/training/data_augmentation/default_data_augmentation.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a2054d058d2a489eb2c0aa57f6ca4d4efb0eff6
--- /dev/null
+++ b/nnunet/training/data_augmentation/default_data_augmentation.py
@@ -0,0 +1,257 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from copy import deepcopy
+
+import numpy as np
+from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
+from batchgenerators.transforms.abstract_transforms import Compose
+from batchgenerators.transforms.channel_selection_transforms import DataChannelSelectionTransform, \
+ SegChannelSelectionTransform
+from batchgenerators.transforms.color_transforms import GammaTransform
+from batchgenerators.transforms.spatial_transforms import SpatialTransform, MirrorTransform
+from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor
+
+from nnunet.training.data_augmentation.custom_transforms import Convert3DTo2DTransform, Convert2DTo3DTransform, \
+ MaskTransform, ConvertSegmentationToRegionsTransform
+from nnunet.training.data_augmentation.pyramid_augmentations import MoveSegAsOneHotToData, \
+ ApplyRandomBinaryOperatorTransform, \
+ RemoveRandomConnectedComponentFromOneHotEncodingTransform
+
+try:
+ from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
+except ImportError as ie:
+ NonDetMultiThreadedAugmenter = None
+
+
+default_3D_augmentation_params = {
+ "selected_data_channels": None,
+ "selected_seg_channels": None,
+
+ "do_elastic": True,
+ "elastic_deform_alpha": (0., 900.),
+ "elastic_deform_sigma": (9., 13.),
+ "p_eldef": 0.2,
+
+ "do_scaling": True,
+ "scale_range": (0.85, 1.25),
+ "independent_scale_factor_for_each_axis": False,
+ "p_independent_scale_per_axis": 1,
+ "p_scale": 0.2,
+
+ "do_rotation": True,
+ "rotation_x": (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
+ "rotation_y": (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
+ "rotation_z": (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi),
+ "rotation_p_per_axis": 1,
+ "p_rot": 0.2,
+
+ "random_crop": False,
+ "random_crop_dist_to_border": None,
+
+ "do_gamma": True,
+ "gamma_retain_stats": True,
+ "gamma_range": (0.7, 1.5),
+ "p_gamma": 0.3,
+
+ "do_mirror": True,
+ "mirror_axes": (0, 1, 2),
+
+ "dummy_2D": False,
+ "mask_was_used_for_normalization": None,
+ "border_mode_data": "constant",
+
+ "all_segmentation_labels": None, # used for cascade
+ "move_last_seg_chanel_to_data": False, # used for cascade
+ "cascade_do_cascade_augmentations": False, # used for cascade
+ "cascade_random_binary_transform_p": 0.4,
+ "cascade_random_binary_transform_p_per_label": 1,
+ "cascade_random_binary_transform_size": (1, 8),
+ "cascade_remove_conn_comp_p": 0.2,
+ "cascade_remove_conn_comp_max_size_percent_threshold": 0.15,
+ "cascade_remove_conn_comp_fill_with_other_class_p": 0.0,
+
+ "do_additive_brightness": False,
+ "additive_brightness_p_per_sample": 0.15,
+ "additive_brightness_p_per_channel": 0.5,
+ "additive_brightness_mu": 0.0,
+ "additive_brightness_sigma": 0.1,
+
+ "num_threads": 12 if 'nnUNet_n_proc_DA' not in os.environ else int(os.environ['nnUNet_n_proc_DA']),
+ "num_cached_per_thread": 1,
+}
+
+default_2D_augmentation_params = deepcopy(default_3D_augmentation_params)
+
+default_2D_augmentation_params["elastic_deform_alpha"] = (0., 200.)
+default_2D_augmentation_params["elastic_deform_sigma"] = (9., 13.)
+default_2D_augmentation_params["rotation_x"] = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
+default_2D_augmentation_params["rotation_y"] = (-0. / 360 * 2. * np.pi, 0. / 360 * 2. * np.pi)
+default_2D_augmentation_params["rotation_z"] = (-0. / 360 * 2. * np.pi, 0. / 360 * 2. * np.pi)
+
+# sometimes you have 3d data and a 3d net but cannot augment it properly in 3d due to anisotropy (which is currently
+# not supported in batchgenerators). In that case you can 'cheat': convert the 3d data to 2d, augment in-plane and
+# convert it back to 3d after augmentation
+default_2D_augmentation_params["dummy_2D"] = False
+default_2D_augmentation_params["mirror_axes"] = (0, 1) # this can be (0, 1, 2) if dummy_2D=True
+
+
+def get_patch_size(final_patch_size, rot_x, rot_y, rot_z, scale_range):
+ if isinstance(rot_x, (tuple, list)):
+ rot_x = max(np.abs(rot_x))
+ if isinstance(rot_y, (tuple, list)):
+ rot_y = max(np.abs(rot_y))
+ if isinstance(rot_z, (tuple, list)):
+ rot_z = max(np.abs(rot_z))
+ rot_x = min(90 / 360 * 2. * np.pi, rot_x)
+ rot_y = min(90 / 360 * 2. * np.pi, rot_y)
+ rot_z = min(90 / 360 * 2. * np.pi, rot_z)
+ from batchgenerators.augmentations.utils import rotate_coords_3d, rotate_coords_2d
+ coords = np.array(final_patch_size)
+ final_shape = np.copy(coords)
+ if len(coords) == 3:
+ final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, rot_x, 0, 0)), final_shape)), 0)
+ final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, rot_y, 0)), final_shape)), 0)
+ final_shape = np.max(np.vstack((np.abs(rotate_coords_3d(coords, 0, 0, rot_z)), final_shape)), 0)
+ elif len(coords) == 2:
+ final_shape = np.max(np.vstack((np.abs(rotate_coords_2d(coords, rot_x)), final_shape)), 0)
+ final_shape /= min(scale_range)
+ return final_shape.astype(int)
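+
+# Hedged usage sketch (added for illustration, not part of the original file): the
+# returned size is the patch to sample *before* rotation/scaling so that the final
+# crop to final_patch_size never has to pad with invalid voxels.
+#
+# padded = get_patch_size((64, 64, 64),
+# default_3D_augmentation_params['rotation_x'],
+# default_3D_augmentation_params['rotation_y'],
+# default_3D_augmentation_params['rotation_z'],
+# default_3D_augmentation_params['scale_range'])
+# assert all(p >= 64 for p in padded)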
+
+
+def get_default_augmentation(dataloader_train, dataloader_val, patch_size, params=default_3D_augmentation_params,
+ border_val_seg=-1, pin_memory=True,
+ seeds_train=None, seeds_val=None, regions=None):
+ assert params.get('mirror') is None, "old version of params, use new keyword do_mirror"
+ tr_transforms = []
+
+ if params.get("selected_data_channels") is not None:
+ tr_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
+
+ if params.get("selected_seg_channels") is not None:
+ tr_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
+
+ # don't do color augmentations while in 2d mode with 3d data because the color channel is overloaded!!
+ if params.get("dummy_2D") is not None and params.get("dummy_2D"):
+ tr_transforms.append(Convert3DTo2DTransform())
+ patch_size_spatial = patch_size[1:]
+ else:
+ patch_size_spatial = patch_size
+
+ tr_transforms.append(SpatialTransform(
+ patch_size_spatial, patch_center_dist_from_border=None, do_elastic_deform=params.get("do_elastic"),
+ alpha=params.get("elastic_deform_alpha"), sigma=params.get("elastic_deform_sigma"),
+ do_rotation=params.get("do_rotation"), angle_x=params.get("rotation_x"), angle_y=params.get("rotation_y"),
+ angle_z=params.get("rotation_z"), do_scale=params.get("do_scaling"), scale=params.get("scale_range"),
+ border_mode_data=params.get("border_mode_data"), border_cval_data=0, order_data=3, border_mode_seg="constant",
+ border_cval_seg=border_val_seg,
+ order_seg=1, random_crop=params.get("random_crop"), p_el_per_sample=params.get("p_eldef"),
+ p_scale_per_sample=params.get("p_scale"), p_rot_per_sample=params.get("p_rot"),
+ independent_scale_for_each_axis=params.get("independent_scale_factor_for_each_axis")
+ ))
+ if params.get("dummy_2D") is not None and params.get("dummy_2D"):
+ tr_transforms.append(Convert2DTo3DTransform())
+
+ if params.get("do_gamma"):
+ tr_transforms.append(
+ GammaTransform(params.get("gamma_range"), False, True, retain_stats=params.get("gamma_retain_stats"),
+ p_per_sample=params["p_gamma"]))
+
+ if params.get("do_mirror"):
+ tr_transforms.append(MirrorTransform(params.get("mirror_axes")))
+
+ if params.get("mask_was_used_for_normalization") is not None:
+ mask_was_used_for_normalization = params.get("mask_was_used_for_normalization")
+ tr_transforms.append(MaskTransform(mask_was_used_for_normalization, mask_idx_in_seg=0, set_outside_to=0))
+
+ tr_transforms.append(RemoveLabelTransform(-1, 0))
+
+ if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
+ tr_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
+ if params.get("cascade_do_cascade_augmentations") and not None and params.get(
+ "cascade_do_cascade_augmentations"):
+ tr_transforms.append(ApplyRandomBinaryOperatorTransform(
+ channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
+ p_per_sample=params.get("cascade_random_binary_transform_p"),
+ key="data",
+ strel_size=params.get("cascade_random_binary_transform_size")))
+ tr_transforms.append(RemoveRandomConnectedComponentFromOneHotEncodingTransform(
+ channel_idx=list(range(-len(params.get("all_segmentation_labels")), 0)),
+ key="data",
+ p_per_sample=params.get("cascade_remove_conn_comp_p"),
+ fill_with_other_class_p=params.get("cascade_remove_conn_comp_max_size_percent_threshold"),
+ dont_do_if_covers_more_than_X_percent=params.get("cascade_remove_conn_comp_fill_with_other_class_p")))
+
+ tr_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ tr_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
+
+ tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+
+ tr_transforms = Compose(tr_transforms)
+ # from batchgenerators.dataloading import SingleThreadedAugmenter
+ # batchgenerator_train = SingleThreadedAugmenter(dataloader_train, tr_transforms)
+ # import IPython;IPython.embed()
+
+ batchgenerator_train = MultiThreadedAugmenter(dataloader_train, tr_transforms, params.get('num_threads'),
+ params.get("num_cached_per_thread"), seeds=seeds_train,
+ pin_memory=pin_memory)
+
+ val_transforms = []
+ val_transforms.append(RemoveLabelTransform(-1, 0))
+ if params.get("selected_data_channels") is not None:
+ val_transforms.append(DataChannelSelectionTransform(params.get("selected_data_channels")))
+ if params.get("selected_seg_channels") is not None:
+ val_transforms.append(SegChannelSelectionTransform(params.get("selected_seg_channels")))
+
+ if params.get("move_last_seg_chanel_to_data") is not None and params.get("move_last_seg_chanel_to_data"):
+ val_transforms.append(MoveSegAsOneHotToData(1, params.get("all_segmentation_labels"), 'seg', 'data'))
+
+ val_transforms.append(RenameTransform('seg', 'target', True))
+
+ if regions is not None:
+ val_transforms.append(ConvertSegmentationToRegionsTransform(regions, 'target', 'target'))
+
+ val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ val_transforms = Compose(val_transforms)
+
+ # batchgenerator_val = SingleThreadedAugmenter(dataloader_val, val_transforms)
+ batchgenerator_val = MultiThreadedAugmenter(dataloader_val, val_transforms, max(params.get('num_threads') // 2, 1),
+ params.get("num_cached_per_thread"), seeds=seeds_val,
+ pin_memory=pin_memory)
+ return batchgenerator_train, batchgenerator_val
+
+
+if __name__ == "__main__":
+ from nnunet.training.dataloading.dataset_loading import DataLoader3D, load_dataset
+ from nnunet.paths import preprocessing_output_dir
+ import os
+ import pickle
+
+ t = "Task002_Heart"
+ p = os.path.join(preprocessing_output_dir, t)
+ dataset = load_dataset(p, 0)
+ with open(os.path.join(p, "plans.pkl"), 'rb') as f:
+ plans = pickle.load(f)
+
+ basic_patch_size = get_patch_size(np.array(plans['stage_properties'][0].patch_size),
+ default_3D_augmentation_params['rotation_x'],
+ default_3D_augmentation_params['rotation_y'],
+ default_3D_augmentation_params['rotation_z'],
+ default_3D_augmentation_params['scale_range'])
+
+ dl = DataLoader3D(dataset, basic_patch_size, np.array(plans['stage_properties'][0].patch_size).astype(int), 1)
+ tr, val = get_default_augmentation(dl, dl, np.array(plans['stage_properties'][0].patch_size).astype(int))
diff --git a/nnunet/training/data_augmentation/downsampling.py b/nnunet/training/data_augmentation/downsampling.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0c6d153588bd283198a2168f5f201d389992b1b
--- /dev/null
+++ b/nnunet/training/data_augmentation/downsampling.py
@@ -0,0 +1,104 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from batchgenerators.augmentations.utils import convert_seg_image_to_one_hot_encoding_batched, resize_segmentation
+from batchgenerators.transforms.abstract_transforms import AbstractTransform
+from torch.nn.functional import avg_pool2d, avg_pool3d
+import numpy as np
+
+
+class DownsampleSegForDSTransform3(AbstractTransform):
+ '''
+ returns one-hot encodings of the segmentation maps wherever downsampling has occurred (no one-hot at the
+ highest resolution). Because the one-hot maps are average-pooled, the downsampled segmentations are smooth
+ probabilities in [0, 1], not hard 0/1 labels.
+
+ returns torch tensors, not numpy arrays!
+
+ always uses seg channel 0!!
+
+ always pass classes! Otherwise the one-hot encoding may be inconsistent across batches
+ '''
+ def __init__(self, ds_scales=((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)), input_key="seg", output_key="seg", classes=None):
+ self.classes = classes
+ self.output_key = output_key
+ self.input_key = input_key
+ self.ds_scales = ds_scales
+
+ def __call__(self, **data_dict):
+ data_dict[self.output_key] = downsample_seg_for_ds_transform3(data_dict[self.input_key][:, 0], self.ds_scales, self.classes)
+ return data_dict
+
+
+def downsample_seg_for_ds_transform3(seg, ds_scales=((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)), classes=None):
+ output = []
+ one_hot = torch.from_numpy(convert_seg_image_to_one_hot_encoding_batched(seg, classes)) # shape: b, c, x, y(, z)
+
+ for s in ds_scales:
+ if all([i == 1 for i in s]):
+ output.append(torch.from_numpy(seg))
+ else:
+ kernel_size = tuple(int(1 / i) for i in s)
+ stride = kernel_size
+ pad = tuple((i-1) // 2 for i in kernel_size)
+
+ if len(s) == 2:
+ pool_op = avg_pool2d
+ elif len(s) == 3:
+ pool_op = avg_pool3d
+ else:
+ raise RuntimeError("unsupported number of spatial dimensions in ds_scales entry: %s" % str(s))
+
+ pooled = pool_op(one_hot, kernel_size, stride, pad, count_include_pad=False, ceil_mode=False)
+
+ output.append(pooled)
+ return output
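+
+# Hedged sketch (added for illustration, not part of the original file): because the
+# one-hot maps are average-pooled, the downsampled outputs are soft probabilities.
+#
+# seg = np.zeros((1, 8, 8), dtype=np.float32) # b, x, y (channel 0 already selected)
+# seg[0, :4] = 1
+# out = downsample_seg_for_ds_transform3(seg, ds_scales=((1, 1), (0.5, 0.5)), classes=[0, 1])
+# out[0].shape == (1, 8, 8) # full resolution: hard labels, no one-hot
+# out[1].shape == (1, 2, 4, 4) # half resolution: soft one-hot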
+
+
+class DownsampleSegForDSTransform2(AbstractTransform):
+ '''
+ data_dict[output_key] will be a list of segmentations scaled according to ds_scales
+ '''
+ def __init__(self, ds_scales=((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)), order=0, input_key="seg", output_key="seg", axes=None):
+ self.axes = axes
+ self.output_key = output_key
+ self.input_key = input_key
+ self.order = order
+ self.ds_scales = ds_scales
+
+ def __call__(self, **data_dict):
+ data_dict[self.output_key] = downsample_seg_for_ds_transform2(data_dict[self.input_key], self.ds_scales,
+ self.order, self.axes)
+ return data_dict
+
+
+def downsample_seg_for_ds_transform2(seg, ds_scales=((1, 1, 1), (0.5, 0.5, 0.5), (0.25, 0.25, 0.25)), order=0, axes=None):
+ if axes is None:
+ axes = list(range(2, len(seg.shape)))
+ output = []
+ for s in ds_scales:
+ if all([i == 1 for i in s]):
+ output.append(seg)
+ else:
+ new_shape = np.array(seg.shape).astype(float)
+ for i, a in enumerate(axes):
+ new_shape[a] *= s[i]
+ new_shape = np.round(new_shape).astype(int)
+ out_seg = np.zeros(new_shape, dtype=seg.dtype)
+ for b in range(seg.shape[0]):
+ for c in range(seg.shape[1]):
+ out_seg[b, c] = resize_segmentation(seg[b, c], new_shape[2:], order)
+ output.append(out_seg)
+ return output
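+
+
+if __name__ == "__main__":
+ # Hedged demo (added for illustration, not part of the original file): transform2
+ # resizes the label map itself (order=0 -> nearest neighbour), so every output
+ # stays a hard label map, unlike the soft one-hot maps of transform3 above.
+ _seg = np.zeros((1, 1, 8, 8), dtype=np.int64)
+ _seg[0, 0, :4] = 1
+ _out = downsample_seg_for_ds_transform2(_seg, ds_scales=((1, 1), (0.5, 0.5)), order=0)
+ print(_out[0].shape, _out[1].shape) # (1, 1, 8, 8) (1, 1, 4, 4)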
diff --git a/nnunet/training/data_augmentation/pyramid_augmentations.py b/nnunet/training/data_augmentation/pyramid_augmentations.py
new file mode 100644
index 0000000000000000000000000000000000000000..504db06a5425b16e5f37390d12cc8d375df5ca52
--- /dev/null
+++ b/nnunet/training/data_augmentation/pyramid_augmentations.py
@@ -0,0 +1,189 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from copy import deepcopy
+
+from batchgenerators.transforms.abstract_transforms import AbstractTransform
+from skimage.morphology import label, ball
+from skimage.morphology.binary import binary_erosion, binary_dilation, binary_closing, binary_opening
+import numpy as np
+
+
+class RemoveRandomConnectedComponentFromOneHotEncodingTransform(AbstractTransform):
+ def __init__(self, channel_idx, key="data", p_per_sample=0.2, fill_with_other_class_p=0.25,
+ dont_do_if_covers_more_than_X_percent=0.25, p_per_label=1):
+ """
+ :param dont_do_if_covers_more_than_X_percent: dont_do_if_covers_more_than_X_percent=0.25 is 25%!
+ :param channel_idx: can be list or int
+ :param key:
+ """
+ self.p_per_label = p_per_label
+ self.dont_do_if_covers_more_than_X_percent = dont_do_if_covers_more_than_X_percent
+ self.fill_with_other_class_p = fill_with_other_class_p
+ self.p_per_sample = p_per_sample
+ self.key = key
+ if not isinstance(channel_idx, (list, tuple)):
+ channel_idx = [channel_idx]
+ self.channel_idx = channel_idx
+
+ def __call__(self, **data_dict):
+ data = data_dict.get(self.key)
+ for b in range(data.shape[0]):
+ if np.random.uniform() < self.p_per_sample:
+ for c in self.channel_idx:
+ if np.random.uniform() < self.p_per_label:
+ workon = np.copy(data[b, c])
+ num_voxels = np.prod(workon.shape, dtype=np.uint64)
+ lab, num_comp = label(workon, return_num=True)
+ if num_comp > 0:
+ component_ids = []
+ component_sizes = []
+ for i in range(1, num_comp + 1):
+ component_ids.append(i)
+ component_sizes.append(np.sum(lab == i))
+ component_ids = [i for i, j in zip(component_ids, component_sizes) if j < num_voxels*self.dont_do_if_covers_more_than_X_percent]
+ if len(component_ids) > 0:
+ random_component = np.random.choice(component_ids)
+ data[b, c][lab == random_component] = 0
+ if np.random.uniform() < self.fill_with_other_class_p:
+ other_ch = [i for i in self.channel_idx if i != c]
+ if len(other_ch) > 0:
+ other_class = np.random.choice(other_ch)
+ data[b, other_class][lab == random_component] = 1
+ data_dict[self.key] = data
+ return data_dict
+
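+
+# Hedged usage sketch (toy shapes and settings are assumptions, not repo defaults):
+# removing a connected component from a one hot channel. The cube below is well
+# under the 25% size cap, so with p_per_sample=1.0 it is always eligible.
+def _example_remove_random_component():
+    data = np.zeros((1, 2, 24, 24, 24), dtype=np.float32)
+    data[0, 1, 8:14, 8:14, 8:14] = 1  # one small component in channel 1
+    tr = RemoveRandomConnectedComponentFromOneHotEncodingTransform(channel_idx=[1], p_per_sample=1.0)
+    return tr(data=data)['data']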
+
+class MoveSegAsOneHotToData(AbstractTransform):
+ def __init__(self, channel_id, all_seg_labels, key_origin="seg", key_target="data", remove_from_origin=True):
+ self.remove_from_origin = remove_from_origin
+ self.all_seg_labels = all_seg_labels
+ self.key_target = key_target
+ self.key_origin = key_origin
+ self.channel_id = channel_id
+
+ def __call__(self, **data_dict):
+ origin = data_dict.get(self.key_origin)
+ target = data_dict.get(self.key_target)
+ seg = origin[:, self.channel_id:self.channel_id+1]
+ seg_onehot = np.zeros((seg.shape[0], len(self.all_seg_labels), *seg.shape[2:]), dtype=seg.dtype)
+ for i, l in enumerate(self.all_seg_labels):
+ seg_onehot[:, i][seg[:, 0] == l] = 1
+ target = np.concatenate((target, seg_onehot), 1)
+ data_dict[self.key_target] = target
+
+ if self.remove_from_origin:
+ remaining_channels = [i for i in range(origin.shape[1]) if i != self.channel_id]
+ origin = origin[:, remaining_channels]
+ data_dict[self.key_origin] = origin
+ return data_dict
+
+
+class ApplyRandomBinaryOperatorTransform(AbstractTransform):
+ def __init__(self, channel_idx, p_per_sample=0.3, any_of_these=(binary_dilation, binary_erosion, binary_closing,
+ binary_opening),
+ key="data", strel_size=(1, 10), p_per_label=1):
+ self.p_per_label = p_per_label
+ self.strel_size = strel_size
+ self.key = key
+ self.any_of_these = any_of_these
+ self.p_per_sample = p_per_sample
+
+        assert not isinstance(channel_idx, tuple), "channel_idx must be an int or a list, not a tuple"
+
+ if not isinstance(channel_idx, list):
+ channel_idx = [channel_idx]
+ self.channel_idx = channel_idx
+
+ def __call__(self, **data_dict):
+ data = data_dict.get(self.key)
+ for b in range(data.shape[0]):
+ if np.random.uniform() < self.p_per_sample:
+ ch = deepcopy(self.channel_idx)
+ np.random.shuffle(ch)
+ for c in ch:
+ if np.random.uniform() < self.p_per_label:
+ operation = np.random.choice(self.any_of_these)
+ selem = ball(np.random.uniform(*self.strel_size))
+ workon = np.copy(data[b, c]).astype(int)
+ res = operation(workon, selem).astype(workon.dtype)
+ data[b, c] = res
+
+ # if class was added, we need to remove it in ALL other channels to keep one hot encoding
+ # properties
+ # we modify data
+ other_ch = [i for i in ch if i != c]
+ if len(other_ch) > 0:
+ was_added_mask = (res - workon) > 0
+ for oc in other_ch:
+ data[b, oc][was_added_mask] = 0
+ # if class was removed, leave it at background
+ data_dict[self.key] = data
+ return data_dict
+
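+
+# Hedged usage sketch (illustrative assumptions only): random morphology applied
+# to a single one hot foreground channel of a toy batch.
+def _example_apply_random_binary_operator():
+    data = np.zeros((1, 2, 24, 24, 24), dtype=np.float32)
+    data[0, 1, 8:16, 8:16, 8:16] = 1  # a cube as the foreground component of channel 1
+    tr = ApplyRandomBinaryOperatorTransform(channel_idx=[1], p_per_sample=1.0, strel_size=(1, 4))
+    return tr(data=data)['data']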
+
+class ApplyRandomBinaryOperatorTransform2(AbstractTransform):
+ def __init__(self, channel_idx, p_per_sample=0.3, p_per_label=0.3, any_of_these=(binary_dilation, binary_closing),
+ key="data", strel_size=(1, 10)):
+ """
+ 2019_11_22: I have no idea what the purpose of this was...
+
+ the same as above but here we should use only expanding operations. Expansions will replace other labels
+ :param channel_idx: can be list or int
+ :param p_per_sample:
+ :param any_of_these:
+ :param fill_diff_with_other_class:
+ :param key:
+ :param strel_size:
+ """
+ self.strel_size = strel_size
+ self.key = key
+ self.any_of_these = any_of_these
+ self.p_per_sample = p_per_sample
+ self.p_per_label = p_per_label
+
+        assert not isinstance(channel_idx, tuple), "channel_idx must be an int or a list, not a tuple"
+
+ if not isinstance(channel_idx, list):
+ channel_idx = [channel_idx]
+ self.channel_idx = channel_idx
+
+ def __call__(self, **data_dict):
+ data = data_dict.get(self.key)
+ for b in range(data.shape[0]):
+ if np.random.uniform() < self.p_per_sample:
+ ch = deepcopy(self.channel_idx)
+ np.random.shuffle(ch)
+ for c in ch:
+ if np.random.uniform() < self.p_per_label:
+ operation = np.random.choice(self.any_of_these)
+ selem = ball(np.random.uniform(*self.strel_size))
+ workon = np.copy(data[b, c]).astype(int)
+ res = operation(workon, selem).astype(workon.dtype)
+ data[b, c] = res
+
+ # if class was added, we need to remove it in ALL other channels to keep one hot encoding
+ # properties
+ # we modify data
+ other_ch = [i for i in ch if i != c]
+ if len(other_ch) > 0:
+ was_added_mask = (res - workon) > 0
+ for oc in other_ch:
+ data[b, oc][was_added_mask] = 0
+                        # if class was removed, leave it at background
+ data_dict[self.key] = data
+ return data_dict
diff --git a/nnunet/training/dataloading/__init__.py b/nnunet/training/dataloading/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b8078b9dddddf22182fec2555d8d118ea72622
--- /dev/null
+++ b/nnunet/training/dataloading/__init__.py
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *
\ No newline at end of file
diff --git a/nnunet/training/dataloading/dataset_loading.py b/nnunet/training/dataloading/dataset_loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..08566ff33cd56e5f1a034d3261f45a56e321f893
--- /dev/null
+++ b/nnunet/training/dataloading/dataset_loading.py
@@ -0,0 +1,607 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from collections import OrderedDict
+import numpy as np
+from multiprocessing import Pool
+
+from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
+
+from nnunet.configuration import default_num_threads
+from nnunet.paths import preprocessing_output_dir
+from batchgenerators.utilities.file_and_folder_operations import *
+
+
+def get_case_identifiers(folder):
+ case_identifiers = [i[:-4] for i in os.listdir(folder) if i.endswith("npz") and (i.find("segFromPrevStage") == -1)]
+ return case_identifiers
+
+
+def get_case_identifiers_from_raw_folder(folder):
+ case_identifiers = np.unique(
+ [i[:-12] for i in os.listdir(folder) if i.endswith(".nii.gz") and (i.find("segFromPrevStage") == -1)])
+ return case_identifiers
+
+
+def convert_to_npy(args):
+ if not isinstance(args, tuple):
+ key = "data"
+ npz_file = args
+ else:
+ npz_file, key = args
+ if not isfile(npz_file[:-3] + "npy"):
+ a = np.load(npz_file)[key]
+ np.save(npz_file[:-3] + "npy", a)
+
+
+def save_as_npz(args):
+ if not isinstance(args, tuple):
+ key = "data"
+ npy_file = args
+ else:
+ npy_file, key = args
+ d = np.load(npy_file)
+ np.savez_compressed(npy_file[:-3] + "npz", **{key: d})
+
+
+def unpack_dataset(folder, threads=default_num_threads, key="data"):
+ """
+    unpacks all npz files in a folder to npy (whatever you want to have unpacked must be saved under key)
+ :param folder:
+ :param threads:
+ :param key:
+ :return:
+ """
+ p = Pool(threads)
+ npz_files = subfiles(folder, True, None, ".npz", True)
+ p.map(convert_to_npy, zip(npz_files, [key] * len(npz_files)))
+ p.close()
+ p.join()
+
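+
+# Hedged usage sketch: typical lifecycle of a preprocessed folder. The path below
+# is a placeholder assumption, not a real location.
+def _example_unpack_train_cleanup(folder="/path/to/preprocessed/stage0"):
+    unpack_dataset(folder)  # decompress all npz to npy for fast reads during training
+    # ... training happens here ...
+    delete_npy(folder)  # reclaim disk space afterwards; the compressed npz files remain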
+
+def pack_dataset(folder, threads=default_num_threads, key="data"):
+ p = Pool(threads)
+ npy_files = subfiles(folder, True, None, ".npy", True)
+ p.map(save_as_npz, zip(npy_files, [key] * len(npy_files)))
+ p.close()
+ p.join()
+
+
+def delete_npy(folder):
+ case_identifiers = get_case_identifiers(folder)
+ npy_files = [join(folder, i + ".npy") for i in case_identifiers]
+ npy_files = [i for i in npy_files if isfile(i)]
+ for n in npy_files:
+ os.remove(n)
+
+
+def load_dataset(folder, num_cases_properties_loading_threshold=1000):
+ # we don't load the actual data but instead return the filename to the np file.
+ print('loading dataset')
+ case_identifiers = get_case_identifiers(folder)
+ case_identifiers.sort()
+ dataset = OrderedDict()
+ for c in case_identifiers:
+ dataset[c] = OrderedDict()
+ dataset[c]['data_file'] = join(folder, "%s.npz" % c)
+
+ # dataset[c]['properties'] = load_pickle(join(folder, "%s.pkl" % c))
+ dataset[c]['properties_file'] = join(folder, "%s.pkl" % c)
+
+ if dataset[c].get('seg_from_prev_stage_file') is not None:
+ dataset[c]['seg_from_prev_stage_file'] = join(folder, "%s_segs.npz" % c)
+
+ if len(case_identifiers) <= num_cases_properties_loading_threshold:
+ print('loading all case properties')
+ for i in dataset.keys():
+ dataset[i]['properties'] = load_pickle(dataset[i]['properties_file'])
+
+ return dataset
+
+
+def crop_2D_image_force_fg(img, crop_size, valid_voxels):
+ """
+ img must be [c, x, y]
+ img[-1] must be the segmentation with segmentation>0 being foreground
+ :param img:
+ :param crop_size:
+ :param valid_voxels: voxels belonging to the selected class
+ :return:
+ """
+ assert len(valid_voxels.shape) == 2
+
+ if type(crop_size) not in (tuple, list):
+ crop_size = [crop_size] * (len(img.shape) - 1)
+ else:
+        assert len(crop_size) == (len(
+            img.shape) - 1), "If you provide a list/tuple as center crop make sure it has the same len as your data has dims (img is c, x, y)"
+
+ # we need to find the center coords that we can crop to without exceeding the image border
+ lb_x = crop_size[0] // 2
+ ub_x = img.shape[1] - crop_size[0] // 2 - crop_size[0] % 2
+ lb_y = crop_size[1] // 2
+ ub_y = img.shape[2] - crop_size[1] // 2 - crop_size[1] % 2
+
+    if len(valid_voxels) == 0:
+        # no voxels of the selected class: fall back to a random center within the valid bounds
+        selected_center_voxel = (np.random.randint(lb_x, ub_x + 1),
+                                 np.random.randint(lb_y, ub_y + 1))
+    else:
+        # valid_voxels is (N, 2); pick one of the N coordinates
+        selected_center_voxel = valid_voxels[np.random.choice(valid_voxels.shape[0]), :]
+
+ selected_center_voxel = np.array(selected_center_voxel)
+ for i in range(2):
+ selected_center_voxel[i] = max(crop_size[i] // 2, selected_center_voxel[i])
+ selected_center_voxel[i] = min(img.shape[i + 1] - crop_size[i] // 2 - crop_size[i] % 2,
+ selected_center_voxel[i])
+
+ result = img[:, (selected_center_voxel[0] - crop_size[0] // 2):(
+ selected_center_voxel[0] + crop_size[0] // 2 + crop_size[0] % 2),
+ (selected_center_voxel[1] - crop_size[1] // 2):(
+ selected_center_voxel[1] + crop_size[1] // 2 + crop_size[1] % 2)]
+ return result
+
+
+class DataLoader3D(SlimDataLoaderBase):
+ def __init__(self, data, patch_size, final_patch_size, batch_size, has_prev_stage=False,
+ oversample_foreground_percent=0.0, memmap_mode="r", pad_mode="edge", pad_kwargs_data=None,
+ pad_sides=None):
+ """
+        This is the basic data loader for 3D networks. It uses preprocessed data as produced by my (Fabian) preprocessing.
+        You can load the data with load_dataset(folder) where folder is the folder where the npz files are located. If there
+        are only npz files present in that folder, the data loader will unpack them on the fly. This may take a while
+        and increase CPU usage. Therefore, I advise you to call unpack_dataset(folder) first, which will unpack all npz
+        to npy. Don't forget to call delete_npy(folder) after you are done with training!
+        Why all the hassle? Well the decathlon dataset is huge. Using npy for everything will consume >1 TB and that is uncool
+        given that I (Fabian) will have to store that permanently on /datasets and my local computer. With this strategy all
+        data is stored in a compressed format (factor 10 smaller) and only unpacked when needed.
+        :param data: get this with load_dataset(folder). Plug the return value in here and you are g2g (good to go)
+        :param patch_size: what patch size will this data loader return? it is common practice to first load larger
+        patches so that a central crop after data augmentation can be done to reduce border artifacts. If unsure, use
+        get_patch_size() from data_augmentation.default_data_augmentation
+        :param final_patch_size: what will the patch finally be cropped to (after data augmentation)? this is the patch
+        size that goes into your network. We need this here because we will pad patients in here so that patches at the
+        border of patients are sampled properly
+        :param batch_size:
+        :param has_prev_stage: whether a segmentation from a previous cascade stage needs to be loaded as well
+        :param oversample_foreground_percent: this fraction of each batch is forced to contain at least some foreground
+        (equal probability for each of the foreground classes)
+        """
+ super(DataLoader3D, self).__init__(data, batch_size, None)
+ if pad_kwargs_data is None:
+ pad_kwargs_data = OrderedDict()
+ self.pad_kwargs_data = pad_kwargs_data
+ self.pad_mode = pad_mode
+ self.oversample_foreground_percent = oversample_foreground_percent
+ self.final_patch_size = final_patch_size
+ self.has_prev_stage = has_prev_stage
+ self.patch_size = patch_size
+ self.list_of_keys = list(self._data.keys())
+ # need_to_pad denotes by how much we need to pad the data so that if we sample a patch of size final_patch_size
+ # (which is what the network will get) these patches will also cover the border of the patients
+ self.need_to_pad = (np.array(patch_size) - np.array(final_patch_size)).astype(int)
+ if pad_sides is not None:
+ if not isinstance(pad_sides, np.ndarray):
+ pad_sides = np.array(pad_sides)
+ self.need_to_pad += pad_sides
+ self.memmap_mode = memmap_mode
+ self.num_channels = None
+ self.pad_sides = pad_sides
+ self.data_shape, self.seg_shape = self.determine_shapes()
+
+ def get_do_oversample(self, batch_idx):
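+        # items with batch_idx >= round(batch_size * (1 - oversample_foreground_percent))
+        # are forced to contain foreground; lower batch indices are sampled freely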
+ return not batch_idx < round(self.batch_size * (1 - self.oversample_foreground_percent))
+
+ def determine_shapes(self):
+ if self.has_prev_stage:
+ num_seg = 2
+ else:
+ num_seg = 1
+
+ k = list(self._data.keys())[0]
+ if isfile(self._data[k]['data_file'][:-4] + ".npy"):
+ case_all_data = np.load(self._data[k]['data_file'][:-4] + ".npy", self.memmap_mode)
+ else:
+ case_all_data = np.load(self._data[k]['data_file'])['data']
+ num_color_channels = case_all_data.shape[0] - 1
+ data_shape = (self.batch_size, num_color_channels, *self.patch_size)
+ seg_shape = (self.batch_size, num_seg, *self.patch_size)
+ return data_shape, seg_shape
+
+ def generate_train_batch(self):
+ selected_keys = np.random.choice(self.list_of_keys, self.batch_size, True, None)
+ data = np.zeros(self.data_shape, dtype=np.float32)
+ seg = np.zeros(self.seg_shape, dtype=np.float32)
+ case_properties = []
+ for j, i in enumerate(selected_keys):
+ # oversampling foreground will improve stability of model training, especially if many patches are empty
+ # (Lung for example)
+ if self.get_do_oversample(j):
+ force_fg = True
+ else:
+ force_fg = False
+
+ if 'properties' in self._data[i].keys():
+ properties = self._data[i]['properties']
+ else:
+ properties = load_pickle(self._data[i]['properties_file'])
+ case_properties.append(properties)
+
+ # cases are stored as npz, but we require unpack_dataset to be run. This will decompress them into npy
+ # which is much faster to access
+ if isfile(self._data[i]['data_file'][:-4] + ".npy"):
+ case_all_data = np.load(self._data[i]['data_file'][:-4] + ".npy", self.memmap_mode)
+ else:
+ case_all_data = np.load(self._data[i]['data_file'])['data']
+
+ # If we are doing the cascade then we will also need to load the segmentation of the previous stage and
+            # concatenate it. Here it will be concatenated to the segmentation because the augmentations need to be
+ # applied to it in segmentation mode. Later in the data augmentation we move it from the segmentations to
+ # the last channel of the data
+ if self.has_prev_stage:
+ if isfile(self._data[i]['seg_from_prev_stage_file'][:-4] + ".npy"):
+ segs_from_previous_stage = np.load(self._data[i]['seg_from_prev_stage_file'][:-4] + ".npy",
+ mmap_mode=self.memmap_mode)[None]
+ else:
+ segs_from_previous_stage = np.load(self._data[i]['seg_from_prev_stage_file'])['data'][None]
+                # we theoretically support several possible previous segmentations from which only one is sampled. But
+ # in practice this feature was never used so it's always only one segmentation
+ seg_key = np.random.choice(segs_from_previous_stage.shape[0])
+ seg_from_previous_stage = segs_from_previous_stage[seg_key:seg_key + 1]
+ assert all([i == j for i, j in zip(seg_from_previous_stage.shape[1:], case_all_data.shape[1:])]), \
+ "seg_from_previous_stage does not match the shape of case_all_data: %s vs %s" % \
+ (str(seg_from_previous_stage.shape[1:]), str(case_all_data.shape[1:]))
+ else:
+ seg_from_previous_stage = None
+
+ # do you trust me? You better do. Otherwise you'll have to go through this mess and honestly there are
+ # better things you could do right now
+
+            # (above) documentation of the day. Nice. Even I myself, coming back 1 month later, have no friggin idea
+            # what's going on. I keep the above documentation just for fun but attempt to make things clearer now
+
+ need_to_pad = self.need_to_pad.copy()
+ for d in range(3):
+ # if case_all_data.shape + need_to_pad is still < patch size we need to pad more! We pad on both sides
+ # always
+ if need_to_pad[d] + case_all_data.shape[d + 1] < self.patch_size[d]:
+ need_to_pad[d] = self.patch_size[d] - case_all_data.shape[d + 1]
+
+ # we can now choose the bbox from -need_to_pad // 2 to shape - patch_size + need_to_pad // 2. Here we
+ # define what the upper and lower bound can be to then sample from them with np.random.randint
+ shape = case_all_data.shape[1:]
+ lb_x = - need_to_pad[0] // 2
+ ub_x = shape[0] + need_to_pad[0] // 2 + need_to_pad[0] % 2 - self.patch_size[0]
+ lb_y = - need_to_pad[1] // 2
+ ub_y = shape[1] + need_to_pad[1] // 2 + need_to_pad[1] % 2 - self.patch_size[1]
+ lb_z = - need_to_pad[2] // 2
+ ub_z = shape[2] + need_to_pad[2] // 2 + need_to_pad[2] % 2 - self.patch_size[2]
+
+ # if not force_fg then we can just sample the bbox randomly from lb and ub. Else we need to make sure we get
+ # at least one of the foreground classes in the patch
+ if not force_fg:
+ bbox_x_lb = np.random.randint(lb_x, ub_x + 1)
+ bbox_y_lb = np.random.randint(lb_y, ub_y + 1)
+ bbox_z_lb = np.random.randint(lb_z, ub_z + 1)
+ else:
+ # these values should have been precomputed
+ if 'class_locations' not in properties.keys():
+ raise RuntimeError("Please rerun the preprocessing with the newest version of nnU-Net!")
+
+ # this saves us a np.unique. Preprocessing already did that for all cases. Neat.
+ foreground_classes = np.array(
+ [i for i in properties['class_locations'].keys() if len(properties['class_locations'][i]) != 0])
+ foreground_classes = foreground_classes[foreground_classes > 0]
+
+ if len(foreground_classes) == 0:
+ # this only happens if some image does not contain foreground voxels at all
+ selected_class = None
+ voxels_of_that_class = None
+ print('case does not contain any foreground classes', i)
+ else:
+ selected_class = np.random.choice(foreground_classes)
+
+ voxels_of_that_class = properties['class_locations'][selected_class]
+
+ if voxels_of_that_class is not None:
+ selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))]
+ # selected voxel is center voxel. Subtract half the patch size to get lower bbox voxel.
+ # Make sure it is within the bounds of lb and ub
+ bbox_x_lb = max(lb_x, selected_voxel[0] - self.patch_size[0] // 2)
+ bbox_y_lb = max(lb_y, selected_voxel[1] - self.patch_size[1] // 2)
+ bbox_z_lb = max(lb_z, selected_voxel[2] - self.patch_size[2] // 2)
+ else:
+ # If the image does not contain any foreground classes, we fall back to random cropping
+ bbox_x_lb = np.random.randint(lb_x, ub_x + 1)
+ bbox_y_lb = np.random.randint(lb_y, ub_y + 1)
+ bbox_z_lb = np.random.randint(lb_z, ub_z + 1)
+
+ bbox_x_ub = bbox_x_lb + self.patch_size[0]
+ bbox_y_ub = bbox_y_lb + self.patch_size[1]
+ bbox_z_ub = bbox_z_lb + self.patch_size[2]
+
+ # whoever wrote this knew what he was doing (hint: it was me). We first crop the data to the region of the
+ # bbox that actually lies within the data. This will result in a smaller array which is then faster to pad.
+ # valid_bbox is just the coord that lied within the data cube. It will be padded to match the patch size
+ # later
+ valid_bbox_x_lb = max(0, bbox_x_lb)
+ valid_bbox_x_ub = min(shape[0], bbox_x_ub)
+ valid_bbox_y_lb = max(0, bbox_y_lb)
+ valid_bbox_y_ub = min(shape[1], bbox_y_ub)
+ valid_bbox_z_lb = max(0, bbox_z_lb)
+ valid_bbox_z_ub = min(shape[2], bbox_z_ub)
+
+ # At this point you might ask yourself why we would treat seg differently from seg_from_previous_stage.
+            # Why not just concatenate them here and forget about the if statements? Well that's because seg needs to
+ # be padded with -1 constant whereas seg_from_previous_stage needs to be padded with 0s (we could also
+ # remove label -1 in the data augmentation but this way it is less error prone)
+ case_all_data = np.copy(case_all_data[:, valid_bbox_x_lb:valid_bbox_x_ub,
+ valid_bbox_y_lb:valid_bbox_y_ub,
+ valid_bbox_z_lb:valid_bbox_z_ub])
+ if seg_from_previous_stage is not None:
+ seg_from_previous_stage = seg_from_previous_stage[:, valid_bbox_x_lb:valid_bbox_x_ub,
+ valid_bbox_y_lb:valid_bbox_y_ub,
+ valid_bbox_z_lb:valid_bbox_z_ub]
+
+ data[j] = np.pad(case_all_data[:-1], ((0, 0),
+ (-min(0, bbox_x_lb), max(bbox_x_ub - shape[0], 0)),
+ (-min(0, bbox_y_lb), max(bbox_y_ub - shape[1], 0)),
+ (-min(0, bbox_z_lb), max(bbox_z_ub - shape[2], 0))),
+ self.pad_mode, **self.pad_kwargs_data)
+
+ seg[j, 0] = np.pad(case_all_data[-1:], ((0, 0),
+ (-min(0, bbox_x_lb), max(bbox_x_ub - shape[0], 0)),
+ (-min(0, bbox_y_lb), max(bbox_y_ub - shape[1], 0)),
+ (-min(0, bbox_z_lb), max(bbox_z_ub - shape[2], 0))),
+ 'constant', **{'constant_values': -1})
+ if seg_from_previous_stage is not None:
+ seg[j, 1] = np.pad(seg_from_previous_stage, ((0, 0),
+ (-min(0, bbox_x_lb),
+ max(bbox_x_ub - shape[0], 0)),
+ (-min(0, bbox_y_lb),
+ max(bbox_y_ub - shape[1], 0)),
+ (-min(0, bbox_z_lb),
+ max(bbox_z_ub - shape[2], 0))),
+ 'constant', **{'constant_values': 0})
+
+ return {'data': data, 'seg': seg, 'properties': case_properties, 'keys': selected_keys}
+
+
+class DataLoader2D(SlimDataLoaderBase):
+ def __init__(self, data, patch_size, final_patch_size, batch_size, oversample_foreground_percent=0.0,
+ memmap_mode="r", pseudo_3d_slices=1, pad_mode="edge",
+ pad_kwargs_data=None, pad_sides=None):
+ """
+        This is the basic data loader for 2D networks. It uses preprocessed data as produced by my (Fabian) preprocessing.
+        You can load the data with load_dataset(folder) where folder is the folder where the npz files are located. If there
+        are only npz files present in that folder, the data loader will unpack them on the fly. This may take a while
+        and increase CPU usage. Therefore, I advise you to call unpack_dataset(folder) first, which will unpack all npz
+        to npy. Don't forget to call delete_npy(folder) after you are done with training!
+        Why all the hassle? Well the decathlon dataset is huge. Using npy for everything will consume >1 TB and that is uncool
+        given that I (Fabian) will have to store that permanently on /datasets and my local computer. With this strategy all
+        data is stored in a compressed format (factor 10 smaller) and only unpacked when needed.
+        :param data: get this with load_dataset(folder). Plug the return value in here and you are g2g (good to go)
+        :param patch_size: what patch size will this data loader return? it is common practice to first load larger
+        patches so that a central crop after data augmentation can be done to reduce border artifacts. If unsure, use
+        get_patch_size() from data_augmentation.default_data_augmentation
+        :param final_patch_size: what will the patch finally be cropped to (after data augmentation)? this is the patch
+        size that goes into your network. We need this here because we will pad patients in here so that patches at the
+        border of patients are sampled properly
+        :param batch_size:
+        :param oversample_foreground_percent: this fraction of each batch is forced to contain at least some foreground
+        (equal probability for each of the foreground classes)
+        :param pseudo_3d_slices: must be odd. 7 = 3 below and 3 above the center slice
+        """
+ super(DataLoader2D, self).__init__(data, batch_size, None)
+ if pad_kwargs_data is None:
+ pad_kwargs_data = OrderedDict()
+ self.pad_kwargs_data = pad_kwargs_data
+ self.pad_mode = pad_mode
+ self.pseudo_3d_slices = pseudo_3d_slices
+ self.oversample_foreground_percent = oversample_foreground_percent
+ self.final_patch_size = final_patch_size
+ self.patch_size = patch_size
+ self.list_of_keys = list(self._data.keys())
+ self.need_to_pad = np.array(patch_size) - np.array(final_patch_size)
+ self.memmap_mode = memmap_mode
+ if pad_sides is not None:
+ if not isinstance(pad_sides, np.ndarray):
+ pad_sides = np.array(pad_sides)
+ self.need_to_pad += pad_sides
+ self.pad_sides = pad_sides
+ self.data_shape, self.seg_shape = self.determine_shapes()
+
+ def determine_shapes(self):
+ num_seg = 1
+
+ k = list(self._data.keys())[0]
+ if isfile(self._data[k]['data_file'][:-4] + ".npy"):
+ case_all_data = np.load(self._data[k]['data_file'][:-4] + ".npy", self.memmap_mode)
+ else:
+ case_all_data = np.load(self._data[k]['data_file'])['data']
+ num_color_channels = case_all_data.shape[0] - num_seg
+ data_shape = (self.batch_size, num_color_channels, *self.patch_size)
+ seg_shape = (self.batch_size, num_seg, *self.patch_size)
+ return data_shape, seg_shape
+
+ def get_do_oversample(self, batch_idx):
+ return not batch_idx < round(self.batch_size * (1 - self.oversample_foreground_percent))
+
+ def generate_train_batch(self):
+ selected_keys = np.random.choice(self.list_of_keys, self.batch_size, True, None)
+
+ data = np.zeros(self.data_shape, dtype=np.float32)
+ seg = np.zeros(self.seg_shape, dtype=np.float32)
+
+ case_properties = []
+ for j, i in enumerate(selected_keys):
+ if 'properties' in self._data[i].keys():
+ properties = self._data[i]['properties']
+ else:
+ properties = load_pickle(self._data[i]['properties_file'])
+ case_properties.append(properties)
+
+ if self.get_do_oversample(j):
+ force_fg = True
+ else:
+ force_fg = False
+
+ if not isfile(self._data[i]['data_file'][:-4] + ".npy"):
+                # let's hope you know what you're doing
+ case_all_data = np.load(self._data[i]['data_file'][:-4] + ".npz")['data']
+ else:
+ case_all_data = np.load(self._data[i]['data_file'][:-4] + ".npy", self.memmap_mode)
+
+ # this is for when there is just a 2d slice in case_all_data (2d support)
+ if len(case_all_data.shape) == 3:
+ case_all_data = case_all_data[:, None]
+
+ # first select a slice. This can be either random (no force fg) or guaranteed to contain some class
+ if not force_fg:
+ random_slice = np.random.choice(case_all_data.shape[1])
+ selected_class = None
+ else:
+ # these values should have been precomputed
+ if 'class_locations' not in properties.keys():
+ raise RuntimeError("Please rerun the preprocessing with the newest version of nnU-Net!")
+
+ foreground_classes = np.array(
+ [i for i in properties['class_locations'].keys() if len(properties['class_locations'][i]) != 0])
+ foreground_classes = foreground_classes[foreground_classes > 0]
+ if len(foreground_classes) == 0:
+ selected_class = None
+ random_slice = np.random.choice(case_all_data.shape[1])
+ print('case does not contain any foreground classes', i)
+ else:
+ selected_class = np.random.choice(foreground_classes)
+
+ voxels_of_that_class = properties['class_locations'][selected_class]
+ valid_slices = np.unique(voxels_of_that_class[:, 0])
+ random_slice = np.random.choice(valid_slices)
+ voxels_of_that_class = voxels_of_that_class[voxels_of_that_class[:, 0] == random_slice]
+ voxels_of_that_class = voxels_of_that_class[:, 1:]
+
+ # now crop case_all_data to contain just the slice of interest. If we want additional slice above and
+ # below the current slice, here is where we get them. We stack those as additional color channels
+ if self.pseudo_3d_slices == 1:
+ case_all_data = case_all_data[:, random_slice]
+ else:
+ # this is very deprecated and will probably not work anymore. If you intend to use this you need to
+ # check this!
+ mn = random_slice - (self.pseudo_3d_slices - 1) // 2
+ mx = random_slice + (self.pseudo_3d_slices - 1) // 2 + 1
+ valid_mn = max(mn, 0)
+ valid_mx = min(mx, case_all_data.shape[1])
+ case_all_seg = case_all_data[-1:]
+ case_all_data = case_all_data[:-1]
+ case_all_data = case_all_data[:, valid_mn:valid_mx]
+ case_all_seg = case_all_seg[:, random_slice]
+ need_to_pad_below = valid_mn - mn
+ need_to_pad_above = mx - valid_mx
+ if need_to_pad_below > 0:
+ shp_for_pad = np.array(case_all_data.shape)
+ shp_for_pad[1] = need_to_pad_below
+ case_all_data = np.concatenate((np.zeros(shp_for_pad), case_all_data), 1)
+ if need_to_pad_above > 0:
+ shp_for_pad = np.array(case_all_data.shape)
+ shp_for_pad[1] = need_to_pad_above
+ case_all_data = np.concatenate((case_all_data, np.zeros(shp_for_pad)), 1)
+ case_all_data = case_all_data.reshape((-1, case_all_data.shape[-2], case_all_data.shape[-1]))
+ case_all_data = np.concatenate((case_all_data, case_all_seg), 0)
+
+ # case all data should now be (c, x, y)
+ assert len(case_all_data.shape) == 3
+
+ # we can now choose the bbox from -need_to_pad // 2 to shape - patch_size + need_to_pad // 2. Here we
+ # define what the upper and lower bound can be to then sample from them with np.random.randint
+
+ need_to_pad = self.need_to_pad.copy()
+ for d in range(2):
+ # if case_all_data.shape + need_to_pad is still < patch size we need to pad more! We pad on both sides
+ # always
+ if need_to_pad[d] + case_all_data.shape[d + 1] < self.patch_size[d]:
+ need_to_pad[d] = self.patch_size[d] - case_all_data.shape[d + 1]
+
+ shape = case_all_data.shape[1:]
+ lb_x = - need_to_pad[0] // 2
+ ub_x = shape[0] + need_to_pad[0] // 2 + need_to_pad[0] % 2 - self.patch_size[0]
+ lb_y = - need_to_pad[1] // 2
+ ub_y = shape[1] + need_to_pad[1] // 2 + need_to_pad[1] % 2 - self.patch_size[1]
+
+ # if not force_fg then we can just sample the bbox randomly from lb and ub. Else we need to make sure we get
+ # at least one of the foreground classes in the patch
+ if not force_fg or selected_class is None:
+ bbox_x_lb = np.random.randint(lb_x, ub_x + 1)
+ bbox_y_lb = np.random.randint(lb_y, ub_y + 1)
+ else:
+ # this saves us a np.unique. Preprocessing already did that for all cases. Neat.
+ selected_voxel = voxels_of_that_class[np.random.choice(len(voxels_of_that_class))]
+ # selected voxel is center voxel. Subtract half the patch size to get lower bbox voxel.
+ # Make sure it is within the bounds of lb and ub
+ bbox_x_lb = max(lb_x, selected_voxel[0] - self.patch_size[0] // 2)
+ bbox_y_lb = max(lb_y, selected_voxel[1] - self.patch_size[1] // 2)
+
+ bbox_x_ub = bbox_x_lb + self.patch_size[0]
+ bbox_y_ub = bbox_y_lb + self.patch_size[1]
+
+ # whoever wrote this knew what he was doing (hint: it was me). We first crop the data to the region of the
+ # bbox that actually lies within the data. This will result in a smaller array which is then faster to pad.
+ # valid_bbox is just the coord that lied within the data cube. It will be padded to match the patch size
+ # later
+ valid_bbox_x_lb = max(0, bbox_x_lb)
+ valid_bbox_x_ub = min(shape[0], bbox_x_ub)
+ valid_bbox_y_lb = max(0, bbox_y_lb)
+ valid_bbox_y_ub = min(shape[1], bbox_y_ub)
+
+ # At this point you might ask yourself why we would treat seg differently from seg_from_previous_stage.
+            # Why not just concatenate them here and forget about the if statements? Well that's because seg needs to
+ # be padded with -1 constant whereas seg_from_previous_stage needs to be padded with 0s (we could also
+ # remove label -1 in the data augmentation but this way it is less error prone)
+
+ case_all_data = case_all_data[:, valid_bbox_x_lb:valid_bbox_x_ub,
+ valid_bbox_y_lb:valid_bbox_y_ub]
+
+ case_all_data_donly = np.pad(case_all_data[:-1], ((0, 0),
+ (-min(0, bbox_x_lb), max(bbox_x_ub - shape[0], 0)),
+ (-min(0, bbox_y_lb), max(bbox_y_ub - shape[1], 0))),
+ self.pad_mode, **self.pad_kwargs_data)
+
+ case_all_data_segonly = np.pad(case_all_data[-1:], ((0, 0),
+ (-min(0, bbox_x_lb), max(bbox_x_ub - shape[0], 0)),
+ (-min(0, bbox_y_lb), max(bbox_y_ub - shape[1], 0))),
+ 'constant', **{'constant_values': -1})
+
+ data[j] = case_all_data_donly
+ seg[j] = case_all_data_segonly
+
+ keys = selected_keys
+ return {'data': data, 'seg': seg, 'properties': case_properties, "keys": keys}
+
+
+if __name__ == "__main__":
+ t = "Task002_Heart"
+ p = join(preprocessing_output_dir, t, "stage1")
+ dataset = load_dataset(p)
+ with open(join(join(preprocessing_output_dir, t), "plans_stage1.pkl"), 'rb') as f:
+ plans = pickle.load(f)
+ unpack_dataset(p)
+ dl = DataLoader3D(dataset, (32, 32, 32), (32, 32, 32), 2, oversample_foreground_percent=0.33)
+ dl = DataLoader3D(dataset, np.array(plans['patch_size']).astype(int), np.array(plans['patch_size']).astype(int), 2,
+ oversample_foreground_percent=0.33)
+ dl2d = DataLoader2D(dataset, (64, 64), np.array(plans['patch_size']).astype(int)[1:], 12,
+ oversample_foreground_percent=0.33)
diff --git a/nnunet/training/learning_rate/__init__.py b/nnunet/training/learning_rate/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/training/learning_rate/poly_lr.py b/nnunet/training/learning_rate/poly_lr.py
new file mode 100644
index 0000000000000000000000000000000000000000..7691d78c1d2917fe56e44341f295698a17338856
--- /dev/null
+++ b/nnunet/training/learning_rate/poly_lr.py
@@ -0,0 +1,17 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def poly_lr(epoch, max_epochs, initial_lr, exponent=0.9):
+ return initial_lr * (1 - epoch / max_epochs)**exponent
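+
+
+# Hedged usage sketch (illustrative; the optimizer handling below is an assumption,
+# not the trainer's actual code): apply the schedule in place before each epoch.
+def apply_poly_lr_example(optimizer, epoch, max_epochs, initial_lr=1e-2):
+    lr = poly_lr(epoch, max_epochs, initial_lr)
+    for param_group in optimizer.param_groups:
+        param_group['lr'] = lr
+    return lr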
diff --git a/nnunet/training/loss_functions/TopK_loss.py b/nnunet/training/loss_functions/TopK_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..322305ec76691926d1fed9140ad0a4f5ab1dc21e
--- /dev/null
+++ b/nnunet/training/loss_functions/TopK_loss.py
@@ -0,0 +1,33 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+from nnunet.training.loss_functions.crossentropy import RobustCrossEntropyLoss
+
+
+class TopKLoss(RobustCrossEntropyLoss):
+ """
+    Network has to have NO NONLINEARITY (it must output raw logits)!
+ """
+ def __init__(self, weight=None, ignore_index=-100, k=10):
+ self.k = k
+        super(TopKLoss, self).__init__(weight, ignore_index=ignore_index, reduction='none')
+
+ def forward(self, inp, target):
+ target = target[:, 0].long()
+ res = super(TopKLoss, self).forward(inp, target)
+ num_voxels = np.prod(res.shape, dtype=np.int64)
+ res, _ = torch.topk(res.view((-1, )), int(num_voxels * self.k / 100), sorted=False)
+ return res.mean()
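+
+
+# Hedged usage sketch with toy shapes (assumptions for illustration): the target
+# keeps the extra channel dimension expected by RobustCrossEntropyLoss.
+def _example_topk_loss():
+    loss = TopKLoss(k=10)  # keep only the hardest 10% of voxel losses
+    logits = torch.rand(2, 3, 16, 16)             # b, c, x, y
+    target = torch.randint(0, 3, (2, 1, 16, 16))  # b, 1, x, y
+    return loss(logits, target)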
diff --git a/nnunet/training/loss_functions/__init__.py b/nnunet/training/loss_functions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b8078b9dddddf22182fec2555d8d118ea72622
--- /dev/null
+++ b/nnunet/training/loss_functions/__init__.py
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *
\ No newline at end of file
diff --git a/nnunet/training/loss_functions/crossentropy.py b/nnunet/training/loss_functions/crossentropy.py
new file mode 100644
index 0000000000000000000000000000000000000000..6195437b452a5caa0a61cfafa997e55a2a510ee7
--- /dev/null
+++ b/nnunet/training/loss_functions/crossentropy.py
@@ -0,0 +1,12 @@
+from torch import nn, Tensor
+
+
+class RobustCrossEntropyLoss(nn.CrossEntropyLoss):
+ """
+ this is just a compatibility layer because my target tensor is float and has an extra dimension
+ """
+ def forward(self, input: Tensor, target: Tensor) -> Tensor:
+ if len(target.shape) == len(input.shape):
+ assert target.shape[1] == 1
+ target = target[:, 0]
+ return super().forward(input, target.long())
\ No newline at end of file
diff --git a/nnunet/training/loss_functions/deep_supervision.py b/nnunet/training/loss_functions/deep_supervision.py
new file mode 100644
index 0000000000000000000000000000000000000000..04dc465fe9215b1c1ae04770ea4eca0f5cdf4bef
--- /dev/null
+++ b/nnunet/training/loss_functions/deep_supervision.py
@@ -0,0 +1,43 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from torch import nn
+
+
+class MultipleOutputLoss2(nn.Module):
+ def __init__(self, loss, weight_factors=None):
+ """
+        use this if you have several outputs and ground truths (both lists of the same length) and the loss should be
+        computed between them (x[0] and y[0], x[1] and y[1] etc)
+ :param loss:
+ :param weight_factors:
+ """
+ super(MultipleOutputLoss2, self).__init__()
+ self.weight_factors = weight_factors
+ self.loss = loss
+
+ def forward(self, x, y):
+ assert isinstance(x, (tuple, list)), "x must be either tuple or list"
+ assert isinstance(y, (tuple, list)), "y must be either tuple or list"
+ if self.weight_factors is None:
+ weights = [1] * len(x)
+ else:
+ weights = self.weight_factors
+
+ l = weights[0] * self.loss(x[0], y[0])
+ for i in range(1, len(x)):
+ if weights[i] != 0:
+ l += weights[i] * self.loss(x[i], y[i])
+ return l
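+
+
+# Hedged usage sketch (toy tensors and weights are assumptions): weighting losses
+# across deep supervision resolutions, highest resolution first.
+def _example_multiple_output_loss():
+    import torch  # local import; this module otherwise only needs torch.nn
+    ds_loss = MultipleOutputLoss2(nn.MSELoss(), weight_factors=[1.0, 0.5])
+    x = [torch.rand(2, 3, 8, 8), torch.rand(2, 3, 4, 4)]  # per-resolution outputs
+    y = [torch.rand(2, 3, 8, 8), torch.rand(2, 3, 4, 4)]  # matching targets
+    return ds_loss(x, y)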
diff --git a/nnunet/training/loss_functions/dice_loss.py b/nnunet/training/loss_functions/dice_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..32b4994de6e38da143daf17fbbc301680d26309a
--- /dev/null
+++ b/nnunet/training/loss_functions/dice_loss.py
@@ -0,0 +1,480 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from nnunet.training.loss_functions.TopK_loss import TopKLoss
+from nnunet.training.loss_functions.focal_loss import FocalLossV2
+from nnunet.training.loss_functions.crossentropy import RobustCrossEntropyLoss
+from nnunet.utilities.nd_softmax import softmax_helper
+from nnunet.utilities.tensor_utilities import sum_tensor
+from torch import nn
+import numpy as np
+
+
+class GDL(nn.Module):
+ def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.,
+ square=False, square_volumes=False):
+ """
+ square_volumes will square the weight term. The paper recommends square_volumes=True; I don't (just an intuition)
+ """
+ super(GDL, self).__init__()
+
+ self.square_volumes = square_volumes
+ self.square = square
+ self.do_bg = do_bg
+ self.batch_dice = batch_dice
+ self.apply_nonlin = apply_nonlin
+ self.smooth = smooth
+
+ def forward(self, x, y, loss_mask=None):
+ shp_x = x.shape
+ shp_y = y.shape
+
+ if self.batch_dice:
+ axes = [0] + list(range(2, len(shp_x)))
+ else:
+ axes = list(range(2, len(shp_x)))
+
+ if len(shp_x) != len(shp_y):
+ y = y.view((shp_y[0], 1, *shp_y[1:]))
+
+ if all([i == j for i, j in zip(x.shape, y.shape)]):
+ # if this is the case then gt is probably already a one hot encoding
+ y_onehot = y
+ else:
+ gt = y.long()
+ y_onehot = torch.zeros(shp_x)
+ if x.device.type == "cuda":
+ y_onehot = y_onehot.cuda(x.device.index)
+ y_onehot.scatter_(1, gt, 1)
+
+ if self.apply_nonlin is not None:
+ x = self.apply_nonlin(x)
+
+ if not self.do_bg:
+ x = x[:, 1:]
+ y_onehot = y_onehot[:, 1:]
+
+ tp, fp, fn, _ = get_tp_fp_fn_tn(
+ x, y_onehot, axes, loss_mask, self.square)
+
+ # GDL weight computation, we use 1/V
+ # add some eps to prevent div by zero
+ volumes = sum_tensor(y_onehot, axes) + 1e-6
+
+ if self.square_volumes:
+ volumes = volumes ** 2
+
+ # apply weights
+ tp = tp / volumes
+ fp = fp / volumes
+ fn = fn / volumes
+
+ # sum over classes
+ if self.batch_dice:
+ axis = 0
+ else:
+ axis = 1
+
+ tp = tp.sum(axis, keepdim=False)
+ fp = fp.sum(axis, keepdim=False)
+ fn = fn.sum(axis, keepdim=False)
+
+ # compute dice
+ dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)
+
+ dc = dc.mean()
+
+ return -dc
+
+
+def get_tp_fp_fn_tn(net_output, gt, axes=None, mask=None, square=False):
+ """
+ net_output must be (b, c, x, y(, z)))
+ gt must be a label map (shape (b, 1, x, y(, z)) OR shape (b, x, y(, z))) or one hot encoding (b, c, x, y(, z))
+ if mask is provided it must have shape (b, 1, x, y(, z)))
+ :param net_output:
+ :param gt:
+ :param axes: can be (, ) = no summation
+ :param mask: mask must be 1 for valid pixels and 0 for invalid pixels
+ :param square: if True then fp, tp and fn will be squared before summation
+ :return:
+ """
+ if axes is None:
+ axes = tuple(range(2, len(net_output.size())))
+
+ shp_x = net_output.shape
+ shp_y = gt.shape
+
+ with torch.no_grad():
+ if len(shp_x) != len(shp_y):
+ gt = gt.view((shp_y[0], 1, *shp_y[1:]))
+
+ if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
+ # if this is the case then gt is probably already a one hot encoding
+ y_onehot = gt
+ else:
+ gt = gt.long()
+ y_onehot = torch.zeros(shp_x, device=net_output.device)
+ y_onehot.scatter_(1, gt, 1)
+
+ tp = net_output * y_onehot
+ fp = net_output * (1 - y_onehot)
+ fn = (1 - net_output) * y_onehot
+ tn = (1 - net_output) * (1 - y_onehot)
+
+ if mask is not None:
+ tp = torch.stack(tuple(x_i * mask[:, 0]
+ for x_i in torch.unbind(tp, dim=1)), dim=1)
+ fp = torch.stack(tuple(x_i * mask[:, 0]
+ for x_i in torch.unbind(fp, dim=1)), dim=1)
+ fn = torch.stack(tuple(x_i * mask[:, 0]
+ for x_i in torch.unbind(fn, dim=1)), dim=1)
+ tn = torch.stack(tuple(x_i * mask[:, 0]
+ for x_i in torch.unbind(tn, dim=1)), dim=1)
+
+ if square:
+ tp = tp ** 2
+ fp = fp ** 2
+ fn = fn ** 2
+ tn = tn ** 2
+
+ if len(axes) > 0:
+ tp = sum_tensor(tp, axes, keepdim=False)
+ fp = sum_tensor(fp, axes, keepdim=False)
+ fn = sum_tensor(fn, axes, keepdim=False)
+ tn = sum_tensor(tn, axes, keepdim=False)
+
+ return tp, fp, fn, tn
+
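+
+# Hedged sketch (toy shapes are assumptions): turning the counts above into a
+# per-sample, per-class soft Dice score.
+def _example_dice_from_counts():
+    pred = torch.rand(2, 3, 8, 8, 8)              # pseudo-probabilities, b, c, x, y, z
+    gt = torch.randint(0, 3, (2, 1, 8, 8, 8))     # label map with channel dim
+    tp, fp, fn, _ = get_tp_fp_fn_tn(pred, gt)     # summed over the spatial axes
+    return (2 * tp) / (2 * tp + fp + fn + 1e-8)   # shape (b, c)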
+
+class SoftDiceLoss(nn.Module):
+ def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.):
+ """
+ """
+ super(SoftDiceLoss, self).__init__()
+
+ self.do_bg = do_bg
+ self.batch_dice = batch_dice
+ self.apply_nonlin = apply_nonlin
+ self.smooth = smooth
+
+ def forward(self, x, y, loss_mask=None):
+ shp_x = x.shape
+
+ if self.batch_dice:
+ axes = [0] + list(range(2, len(shp_x)))
+ else:
+ axes = list(range(2, len(shp_x)))
+
+ if self.apply_nonlin is not None:
+ x = self.apply_nonlin(x)
+
+ tp, fp, fn, _ = get_tp_fp_fn_tn(x, y, axes, loss_mask, False)
+
+ nominator = 2 * tp + self.smooth
+ denominator = 2 * tp + fp + fn + self.smooth
+
+ dc = nominator / (denominator + 1e-8)
+
+ if not self.do_bg:
+ if self.batch_dice:
+ dc = dc[1:]
+ else:
+ dc = dc[:, 1:]
+ dc = dc.mean()
+
+ return -dc
+
+
+class MCCLoss(nn.Module):
+ def __init__(self, apply_nonlin=None, batch_mcc=False, do_bg=True, smooth=0.0):
+ """
+ based on matthews correlation coefficient
+ https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
+
+ Does not work. Really unstable. F this.
+ """
+ super(MCCLoss, self).__init__()
+
+ self.smooth = smooth
+ self.do_bg = do_bg
+ self.batch_mcc = batch_mcc
+ self.apply_nonlin = apply_nonlin
+
+ def forward(self, x, y, loss_mask=None):
+ shp_x = x.shape
+ voxels = np.prod(shp_x[2:])
+
+ if self.batch_mcc:
+ axes = [0] + list(range(2, len(shp_x)))
+ else:
+ axes = list(range(2, len(shp_x)))
+
+ if self.apply_nonlin is not None:
+ x = self.apply_nonlin(x)
+
+ tp, fp, fn, tn = get_tp_fp_fn_tn(x, y, axes, loss_mask, False)
+ tp /= voxels
+ fp /= voxels
+ fn /= voxels
+ tn /= voxels
+
+ nominator = tp * tn - fp * fn + self.smooth
+ denominator = ((tp + fp) * (tp + fn) * (tn + fp)
+ * (tn + fn)) ** 0.5 + self.smooth
+
+ mcc = nominator / denominator
+
+ if not self.do_bg:
+ if self.batch_mcc:
+ mcc = mcc[1:]
+ else:
+ mcc = mcc[:, 1:]
+ mcc = mcc.mean()
+
+ return -mcc
+
+
+class SoftDiceLossSquared(nn.Module):
+ def __init__(self, apply_nonlin=None, batch_dice=False, do_bg=True, smooth=1.):
+ """
+ squares the terms in the denominator as proposed by Milletari et al.
+ """
+ super(SoftDiceLossSquared, self).__init__()
+
+ self.do_bg = do_bg
+ self.batch_dice = batch_dice
+ self.apply_nonlin = apply_nonlin
+ self.smooth = smooth
+
+ def forward(self, x, y, loss_mask=None):
+ shp_x = x.shape
+ shp_y = y.shape
+
+ if self.batch_dice:
+ axes = [0] + list(range(2, len(shp_x)))
+ else:
+ axes = list(range(2, len(shp_x)))
+
+ if self.apply_nonlin is not None:
+ x = self.apply_nonlin(x)
+
+ with torch.no_grad():
+ if len(shp_x) != len(shp_y):
+ y = y.view((shp_y[0], 1, *shp_y[1:]))
+
+ if all([i == j for i, j in zip(x.shape, y.shape)]):
+ # if this is the case then gt is probably already a one hot encoding
+ y_onehot = y
+ else:
+ y = y.long()
+ y_onehot = torch.zeros(shp_x)
+ if x.device.type == "cuda":
+ y_onehot = y_onehot.cuda(x.device.index)
+                y_onehot.scatter_(1, y, 1)
+
+ intersect = x * y_onehot
+ # values in the denominator get smoothed
+ denominator = x ** 2 + y_onehot ** 2
+
+ # aggregation was previously done in get_tp_fp_fn, but needs to be done here now (needs to be done after
+ # squaring)
+ intersect = sum_tensor(intersect, axes, False) + self.smooth
+ denominator = sum_tensor(denominator, axes, False) + self.smooth
+
+ dc = 2 * intersect / denominator
+
+ if not self.do_bg:
+ if self.batch_dice:
+ dc = dc[1:]
+ else:
+ dc = dc[:, 1:]
+ dc = dc.mean()
+
+ return -dc
+
+
+class DC_and_CE_loss(nn.Module):
+ def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum", square_dice=False, weight_ce=1, weight_dice=1,
+ log_dice=False, ignore_label=None):
+ """
+ CAREFUL. Weights for CE and Dice do not need to sum to one. You can set whatever you want.
+ :param soft_dice_kwargs:
+ :param ce_kwargs:
+ :param aggregate:
+ :param square_dice:
+ :param weight_ce:
+ :param weight_dice:
+ """
+ super(DC_and_CE_loss, self).__init__()
+ if ignore_label is not None:
+ assert not square_dice, 'not implemented'
+ ce_kwargs['reduction'] = 'none'
+ self.log_dice = log_dice
+ self.weight_dice = weight_dice
+ self.weight_ce = weight_ce
+ self.aggregate = aggregate
+ self.ce = RobustCrossEntropyLoss(**ce_kwargs)
+
+ self.ignore_label = ignore_label
+
+ if not square_dice:
+ self.dc = SoftDiceLoss(
+ apply_nonlin=softmax_helper, **soft_dice_kwargs)
+ else:
+ self.dc = SoftDiceLossSquared(
+ apply_nonlin=softmax_helper, **soft_dice_kwargs)
+
+ def forward(self, net_output, target):
+ """
+ target must be b, c, x, y(, z) with c=1
+ :param net_output:
+ :param target:
+ :return:
+ """
+ if self.ignore_label is not None:
+ assert target.shape[1] == 1, 'not implemented for one hot encoding'
+ mask = target != self.ignore_label
+ target[~mask] = 0
+ mask = mask.float()
+ else:
+ mask = None
+
+ dc_loss = self.dc(net_output, target,
+ loss_mask=mask) if self.weight_dice != 0 else 0
+ if self.log_dice:
+ dc_loss = -torch.log(-dc_loss)
+
+ ce_loss = self.ce(
+ net_output, target[:, 0].long()) if self.weight_ce != 0 else 0
+ if self.ignore_label is not None:
+ ce_loss *= mask[:, 0]
+ ce_loss = ce_loss.sum() / mask.sum()
+
+ if self.aggregate == "sum":
+ result = self.weight_ce * ce_loss + self.weight_dice * dc_loss
+ else:
+ # reserved for other stuff (later)
+ raise NotImplementedError("nah son")
+ return result
+
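+
+# Hedged usage sketch: the kwargs and shapes below are illustrative assumptions,
+# not the values the nnU-Net trainers actually pass in.
+def _example_dc_and_ce():
+    loss = DC_and_CE_loss({'batch_dice': True, 'smooth': 1e-5, 'do_bg': False}, {})
+    logits = torch.rand(2, 3, 8, 8, 8)                     # raw network outputs
+    target = torch.randint(0, 3, (2, 1, 8, 8, 8)).float()  # b, 1, x, y, z label map
+    return loss(logits, target)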
+
+class DC_and_BCE_loss(nn.Module):
+ def __init__(self, bce_kwargs, soft_dice_kwargs, aggregate="sum"):
+ """
+ DO NOT APPLY NONLINEARITY IN YOUR NETWORK!
+
+ THIS LOSS IS INTENDED TO BE USED FOR BRATS REGIONS ONLY
+ :param soft_dice_kwargs:
+ :param bce_kwargs:
+ :param aggregate:
+ """
+ super(DC_and_BCE_loss, self).__init__()
+
+ self.aggregate = aggregate
+ self.ce = nn.BCEWithLogitsLoss(**bce_kwargs)
+ self.dc = SoftDiceLoss(apply_nonlin=torch.sigmoid, **soft_dice_kwargs)
+
+ def forward(self, net_output, target):
+ ce_loss = self.ce(net_output, target)
+ dc_loss = self.dc(net_output, target)
+
+ if self.aggregate == "sum":
+ result = ce_loss + dc_loss
+ else:
+ # reserved for other stuff (later)
+ raise NotImplementedError("nah son")
+
+ return result
+
+
+class GDL_and_CE_loss(nn.Module):
+ def __init__(self, gdl_dice_kwargs, ce_kwargs, aggregate="sum"):
+ super(GDL_and_CE_loss, self).__init__()
+ self.aggregate = aggregate
+ self.ce = RobustCrossEntropyLoss(**ce_kwargs)
+ self.dc = GDL(softmax_helper, **gdl_dice_kwargs)
+
+ def forward(self, net_output, target):
+ dc_loss = self.dc(net_output, target)
+ ce_loss = self.ce(net_output, target)
+ if self.aggregate == "sum":
+ result = ce_loss + dc_loss
+ else:
+ # reserved for other stuff (later)
+ raise NotImplementedError("nah son")
+ return result
+
+
+class DC_and_topk_loss(nn.Module):
+ def __init__(self, soft_dice_kwargs, ce_kwargs, aggregate="sum", square_dice=False):
+ super(DC_and_topk_loss, self).__init__()
+ self.aggregate = aggregate
+ self.ce = TopKLoss(**ce_kwargs)
+ if not square_dice:
+ self.dc = SoftDiceLoss(
+ apply_nonlin=softmax_helper, **soft_dice_kwargs)
+ else:
+ self.dc = SoftDiceLossSquared(
+ apply_nonlin=softmax_helper, **soft_dice_kwargs)
+
+ def forward(self, net_output, target):
+ dc_loss = self.dc(net_output, target)
+ ce_loss = self.ce(net_output, target)
+ if self.aggregate == "sum":
+ result = ce_loss + dc_loss
+ else:
+ # reserved for other stuff (later?)
+ raise NotImplementedError("nah son")
+ return result
+
+
+class DC_and_Focal_loss(nn.Module):
+ def __init__(self, soft_dice_kwargs, focal_kwargs, aggregate="sum", square_dice=False, weight_focal=1, weight_dice=1, log_dice=False):
+ super(DC_and_Focal_loss, self).__init__()
+ self.aggregate = aggregate
+ self.focal = FocalLossV2(apply_nonlin=softmax_helper, **focal_kwargs)
+ self.log_dice = log_dice
+ self.weight_focal = weight_focal
+ self.weight_dice = weight_dice
+
+ if not square_dice:
+ self.dc = SoftDiceLoss(
+ apply_nonlin=softmax_helper, **soft_dice_kwargs)
+ else:
+ self.dc = SoftDiceLossSquared(
+ apply_nonlin=softmax_helper, **soft_dice_kwargs)
+
+ def forward(self, net_output, target):
+ """
+ target must be b, c, x, y(, z) with c=1
+ :param net_output:
+ :param target:
+ :return:
+ """
+ dc_loss = self.dc(net_output, target) if self.weight_dice != 0 else 0
+ if self.log_dice:
+ dc_loss = -torch.log(-dc_loss)
+
+ focal_loss = self.focal(
+ net_output, target[:, 0].long()) if self.weight_focal != 0 else 0
+
+ if self.aggregate == "sum":
+ result = self.weight_focal * focal_loss + self.weight_dice * dc_loss
+ else:
+ # reserved for other stuff (later)
+ raise NotImplementedError("nah son")
+ return result
diff --git a/nnunet/training/loss_functions/focal_loss.py b/nnunet/training/loss_functions/focal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..65cebe1ed23bfb034054657968e5075e405b74da
--- /dev/null
+++ b/nnunet/training/loss_functions/focal_loss.py
@@ -0,0 +1,195 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+from torch import nn
+from nnunet.utilities.nd_softmax import softmax_helper
+#from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+# taken from https://github.com/JunMa11/SegLoss/blob/master/test/nnUNetV2/loss_functions/focal_loss.py
+class FocalLoss(nn.Module):
+ """
+    copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
+    This is an implementation of Focal Loss with label-smoothed cross entropy support, as proposed in
+    'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002)
+    Focal_Loss = -1 * alpha * (1 - pt)^gamma * log(pt)
+    :param num_class:
+    :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
+    :param gamma: (float, double) gamma > 0 reduces the relative loss for well-classified examples (p > 0.5), putting
+    more focus on hard, misclassified examples
+    :param smooth: (float, double) smoothing value for the cross entropy
+    :param balance_index: (int) balance class index, should be specified when alpha is float
+    :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
+    """
+
+ def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
+ super(FocalLoss, self).__init__()
+ self.apply_nonlin = apply_nonlin
+ self.alpha = alpha
+ self.gamma = gamma
+ self.balance_index = balance_index
+ self.smooth = smooth
+ self.size_average = size_average
+
+ if self.smooth is not None:
+ if self.smooth < 0 or self.smooth > 1.0:
+ raise ValueError('smooth value should be in [0,1]')
+
+ def forward(self, logit, target):
+ if self.apply_nonlin is not None:
+ logit = self.apply_nonlin(logit)
+ num_class = logit.shape[1]
+
+ if logit.dim() > 2:
+ # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
+ logit = logit.view(logit.size(0), logit.size(1), -1)
+ logit = logit.permute(0, 2, 1).contiguous()
+ logit = logit.view(-1, logit.size(-1))
+ target = torch.squeeze(target, 1)
+ target = target.view(-1, 1)
+ # print(logit.shape, target.shape)
+ #
+ alpha = self.alpha
+
+ if alpha is None:
+ alpha = torch.ones(num_class, 1)
+ elif isinstance(alpha, (list, np.ndarray)):
+ assert len(alpha) == num_class
+ alpha = torch.FloatTensor(alpha).view(num_class, 1)
+ alpha = alpha / alpha.sum()
+ elif isinstance(alpha, float):
+ alpha = torch.ones(num_class, 1)
+ alpha = alpha * (1 - self.alpha)
+ alpha[self.balance_index] = self.alpha
+
+ else:
+            raise TypeError('Unsupported alpha type')
+
+ if alpha.device != logit.device:
+ alpha = alpha.to(logit.device)
+
+ idx = target.cpu().long()
+
+ one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
+ one_hot_key = one_hot_key.scatter_(1, idx, 1)
+ if one_hot_key.device != logit.device:
+ one_hot_key = one_hot_key.to(logit.device)
+
+ if self.smooth:
+ one_hot_key = torch.clamp(
+ one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
+ pt = (one_hot_key * logit).sum(1) + self.smooth
+ logpt = pt.log()
+
+ gamma = self.gamma
+
+ alpha = alpha[idx]
+ alpha = torch.squeeze(alpha)
+ loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
+
+ if self.size_average:
+ loss = loss.mean()
+ else:
+ loss = loss.sum()
+ return loss
+
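+# Usage sketch (shapes and parameter values are illustrative assumptions):
+#
+# focal = FocalLoss(apply_nonlin=softmax_helper, alpha=0.25, gamma=2, balance_index=0)
+# logits = torch.randn(2, 3, 64, 64)            # N, C, H, W
+# labels = torch.randint(0, 3, (2, 1, 64, 64))  # N, 1, H, W integer class ids
+# l = focal(logits, labels)
+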
+
+# taken from https://github.com/JunMa11/SegLoss/blob/master/test/nnUNetV2/loss_functions/focal_loss.py
+class FocalLossV2(nn.Module):
+ """
+ copy from: https://github.com/Hsuxu/Loss_ToolBox-PyTorch/blob/master/FocalLoss/FocalLoss.py
+    This is an implementation of Focal Loss, with support for label-smoothed cross entropy, as proposed in
+    'Focal Loss for Dense Object Detection' (https://arxiv.org/abs/1708.02002)
+    Focal_Loss = -1 * alpha * (1 - pt)^gamma * log(pt)
+    :param num_class:
+    :param alpha: (tensor) 3D or 4D the scalar factor for this criterion
+    :param gamma: (float, double) gamma > 0 reduces the relative loss for well-classified examples (p > 0.5),
+        putting more focus on hard, misclassified examples
+    :param smooth: (float, double) smoothing value applied to the one-hot labels (label smoothing)
+    :param balance_index: (int) index of the class to balance; must be specified when alpha is a float
+ :param size_average: (bool, optional) By default, the losses are averaged over each loss element in the batch.
+ """
+
+ def __init__(self, apply_nonlin=None, alpha=None, gamma=2, balance_index=0, smooth=1e-5, size_average=True):
+ super(FocalLossV2, self).__init__()
+ self.apply_nonlin = apply_nonlin
+ self.alpha = alpha
+ self.gamma = gamma
+ self.balance_index = balance_index
+ self.smooth = smooth
+ self.size_average = size_average
+
+ if self.smooth is not None:
+ if self.smooth < 0 or self.smooth > 1.0:
+ raise ValueError('smooth value should be in [0,1]')
+
+ def forward(self, logit, target):
+ if self.apply_nonlin is not None:
+ logit = self.apply_nonlin(logit)
+ num_class = logit.shape[1]
+
+ if logit.dim() > 2:
+ # N,C,d1,d2 -> N,C,m (m=d1*d2*...)
+ logit = logit.view(logit.size(0), logit.size(1), -1)
+ logit = logit.permute(0, 2, 1).contiguous()
+ logit = logit.view(-1, logit.size(-1))
+ target = torch.squeeze(target, 1)
+ target = target.view(-1, 1)
+ # print(logit.shape, target.shape)
+ #
+ alpha = self.alpha
+
+ if alpha is None:
+ alpha = torch.ones(num_class, 1)
+ elif isinstance(alpha, (list, np.ndarray)):
+ assert len(alpha) == num_class
+ alpha = torch.FloatTensor(alpha).view(num_class, 1)
+ alpha = alpha / alpha.sum()
+ elif isinstance(alpha, float):
+ alpha = torch.ones(num_class, 1)
+ alpha = alpha * (1 - self.alpha)
+ alpha[self.balance_index] = self.alpha
+
+ else:
+            raise TypeError('Unsupported alpha type')
+
+ if alpha.device != logit.device:
+ alpha = alpha.to(logit.device)
+
+ idx = target.cpu().long()
+
+ one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
+ one_hot_key = one_hot_key.scatter_(1, idx, 1)
+ if one_hot_key.device != logit.device:
+ one_hot_key = one_hot_key.to(logit.device)
+
+ if self.smooth:
+ one_hot_key = torch.clamp(
+ one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
+ pt = (one_hot_key * logit).sum(1) + self.smooth
+ logpt = pt.log()
+
+ gamma = self.gamma
+
+ alpha = alpha[idx]
+ alpha = torch.squeeze(alpha)
+ loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
+
+ if self.size_average:
+ loss = loss.mean()
+ else:
+ loss = loss.sum()
+ return loss
diff --git a/nnunet/training/model_restore.py b/nnunet/training/model_restore.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2be12d54e756ac955056fec6df27b5d6fd247bd
--- /dev/null
+++ b/nnunet/training/model_restore.py
@@ -0,0 +1,155 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import nnunet
+import torch
+from batchgenerators.utilities.file_and_folder_operations import *
+import importlib
+import pkgutil
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+
+
+def recursive_find_python_class(folder, trainer_name, current_module):
+ tr = None
+ for importer, modname, ispkg in pkgutil.iter_modules(folder):
+ # print(modname, ispkg)
+ if not ispkg:
+ m = importlib.import_module(current_module + "." + modname)
+ if hasattr(m, trainer_name):
+ tr = getattr(m, trainer_name)
+ break
+
+ if tr is None:
+ for importer, modname, ispkg in pkgutil.iter_modules(folder):
+ if ispkg:
+ next_current_module = current_module + "." + modname
+ tr = recursive_find_python_class([join(folder[0], modname)], trainer_name, current_module=next_current_module)
+ if tr is not None:
+ break
+
+ return tr
+
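+# Example (a sketch): locate a trainer class by name somewhere below
+# nnunet.training.network_training:
+#
+# search_in = join(nnunet.__path__[0], "training", "network_training")
+# trainer_class = recursive_find_python_class([search_in], "nnUNetTrainerV2",
+#                                             current_module="nnunet.training.network_training")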
+
+def restore_model(pkl_file, checkpoint=None, train=False, fp16=None):
+ """
+    This is a utility function to load any nnUNet trainer from a pkl. It will recursively search
+    nnunet.training.network_training for the module that contains the trainer and instantiate it with the arguments
+    saved in the pkl file. If checkpoint is specified, it will furthermore load the checkpoint file in train/test
+    mode (as specified by train).
+ The pkl file required here is the one that will be saved automatically when calling nnUNetTrainer.save_checkpoint.
+ :param pkl_file:
+ :param checkpoint:
+ :param train:
+ :param fp16: if None then we take no action. If True/False we overwrite what the model has in its init
+ :return:
+ """
+ info = load_pickle(pkl_file)
+ init = info['init']
+ name = info['name']
+ search_in = join(nnunet.__path__[0], "training", "network_training")
+ tr = recursive_find_python_class([search_in], name, current_module="nnunet.training.network_training")
+
+ if tr is None:
+ """
+ Fabian only. This will trigger searching for trainer classes in other repositories as well
+ """
+ try:
+ import meddec
+ search_in = join(meddec.__path__[0], "model_training")
+ tr = recursive_find_python_class([search_in], name, current_module="meddec.model_training")
+ except ImportError:
+ pass
+
+ if tr is None:
+        raise RuntimeError("Could not find the model trainer specified in checkpoint in nnunet.training.network_training. If it "
+                           "is not located there, please move it or change the code of restore_model. Your model "
+                           "trainer can be located in any directory within nnunet.training.network_training (search is recursive)."
+ "\nDebug info: \ncheckpoint file: %s\nName of trainer: %s " % (checkpoint, name))
+ assert issubclass(tr, nnUNetTrainer), "The network trainer was found but is not a subclass of nnUNetTrainer. " \
+ "Please make it so!"
+
+ # this is now deprecated
+ """if len(init) == 7:
+ print("warning: this model seems to have been saved with a previous version of nnUNet. Attempting to load it "
+ "anyways. Expect the unexpected.")
+ print("manually editing init args...")
+ init = [init[i] for i in range(len(init)) if i != 2]"""
+
+ # ToDo Fabian make saves use kwargs, please...
+
+ trainer = tr(*init)
+
+ # We can hack fp16 overwriting into the trainer without changing the init arguments because nothing happens with
+ # fp16 in the init, it just saves it to a member variable
+ if fp16 is not None:
+ trainer.fp16 = fp16
+
+ trainer.process_plans(info['plans'])
+ if checkpoint is not None:
+ trainer.load_checkpoint(checkpoint, train)
+ return trainer
+
+
+def load_best_model_for_inference(folder):
+ checkpoint = join(folder, "model_best.model")
+ pkl_file = checkpoint + ".pkl"
+ return restore_model(pkl_file, checkpoint, False)
+
+
+def load_model_and_checkpoint_files(folder, folds=None, mixed_precision=None, checkpoint_name="model_best"):
+ """
+    Use this if you need to ensemble the five models of a cross-validation. This will restore the model from the
+    checkpoint in fold 0, load the parameters of all five folds into RAM and return both. This allows for fast
+    switching between parameters (as opposed to loading them from disk each time).
+
+ This is best used for inference and test prediction
+ :param folder:
+ :param folds:
+ :param mixed_precision: if None then we take no action. If True/False we overwrite what the model has in its init
+ :return:
+ """
+ if isinstance(folds, str):
+ folds = [join(folder, "all")]
+        assert isdir(folds[0]), "no output folder for fold %s found" % folds[0]
+ elif isinstance(folds, (list, tuple)):
+ if len(folds) == 1 and folds[0] == "all":
+ folds = [join(folder, "all")]
+ else:
+ folds = [join(folder, "fold_%d" % i) for i in folds]
+ assert all([isdir(i) for i in folds]), "list of folds specified but not all output folders are present"
+ elif isinstance(folds, int):
+ folds = [join(folder, "fold_%d" % folds)]
+        assert all([isdir(i) for i in folds]), "output folder missing for fold: %s" % folds[0]
+ elif folds is None:
+ print("folds is None so we will automatically look for output folders (not using \'all\'!)")
+ folds = subfolders(folder, prefix="fold")
+ print("found the following folds: ", folds)
+ else:
+        raise ValueError("Unknown value for folds. Type: %s. Expected: list of int, int, str or None" % str(type(folds)))
+
+ trainer = restore_model(join(folds[0], "%s.model.pkl" % checkpoint_name), fp16=mixed_precision)
+ trainer.output_folder = folder
+ trainer.output_folder_base = folder
+ trainer.update_fold(0)
+ trainer.initialize(False)
+ all_best_model_files = [join(i, "%s.model" % checkpoint_name) for i in folds]
+ print("using the following model files: ", all_best_model_files)
+ all_params = [torch.load(i, map_location=torch.device('cpu')) for i in all_best_model_files]
+ return trainer, all_params
+
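+# Typical ensembling pattern at inference time (a sketch; the folder path and fold ids
+# below are assumptions):
+#
+# trainer, all_params = load_model_and_checkpoint_files("/path/to/trained/model", folds=(0, 1, 2, 3, 4))
+# for params in all_params:
+#     trainer.load_checkpoint_ram(params, False)  # swap weights without reloading from disk
+#     # ... run trainer.network on the data and accumulate the predicted probabilities ...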
+
+if __name__ == "__main__":
+ pkl = "/home/fabian/PhD/results/nnUNetV2/nnUNetV2_3D_fullres/Task004_Hippocampus/fold0/model_best.model.pkl"
+ checkpoint = pkl[:-4]
+ train = False
+ trainer = restore_model(pkl, checkpoint, train)
diff --git a/nnunet/training/network_training/__init__.py b/nnunet/training/network_training/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b8078b9dddddf22182fec2555d8d118ea72622
--- /dev/null
+++ b/nnunet/training/network_training/__init__.py
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *
\ No newline at end of file
diff --git a/nnunet/training/network_training/competitions_with_custom_Trainers/BraTS2020/__init__.py b/nnunet/training/network_training/competitions_with_custom_Trainers/BraTS2020/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/training/network_training/competitions_with_custom_Trainers/BraTS2020/nnUNetTrainerV2BraTSRegions.py b/nnunet/training/network_training/competitions_with_custom_Trainers/BraTS2020/nnUNetTrainerV2BraTSRegions.py
new file mode 100644
index 0000000000000000000000000000000000000000..a462c13758166404c8d69190ec6d573098ff9d97
--- /dev/null
+++ b/nnunet/training/network_training/competitions_with_custom_Trainers/BraTS2020/nnUNetTrainerV2BraTSRegions.py
@@ -0,0 +1,420 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from time import sleep
+
+import numpy as np
+import torch
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation
+from torch import nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.nn.utils import clip_grad_norm_
+
+from nnunet.evaluation.region_based_evaluation import evaluate_regions, get_brats_regions
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.training.dataloading.dataset_loading import unpack_dataset
+from nnunet.training.loss_functions.deep_supervision import MultipleOutputLoss2
+from nnunet.training.loss_functions.dice_loss import DC_and_BCE_loss, get_tp_fp_fn_tn, SoftDiceLoss
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.training.network_training.nnUNetTrainerV2_DDP import nnUNetTrainerV2_DDP
+from nnunet.utilities.distributed import awesome_allgather_function
+from nnunet.utilities.to_torch import maybe_to_torch, to_cuda
+
+
+class nnUNetTrainerV2BraTSRegions_BN(nnUNetTrainerV2):
+ def initialize_network(self):
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.BatchNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.BatchNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
+ dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = torch.nn.Softmax(1)
+
+
+class nnUNetTrainerV2BraTSRegions(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.regions = get_brats_regions()
+ self.regions_class_order = (1, 2, 3)
+ self.loss = DC_and_BCE_loss({}, {'batch_dice': False, 'do_bg': True, 'smooth': 0})
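+        # BraTS regions (whole tumor / tumor core / enhancing tumor) overlap, so a voxel can be
+        # positive for several outputs at once; hence sigmoid + BCE here instead of softmax + CE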
+
+    def process_plans(self, plans):
+        """
+        The network has as many outputs as we have regions
+        """
+        super().process_plans(plans)
+        self.num_classes = len(self.regions)
+
+ def initialize_network(self):
+ """inference_apply_nonlin to sigmoid"""
+ super().initialize_network()
+ self.network.inference_apply_nonlin = nn.Sigmoid()
+
+ def initialize(self, training=True, force_load_plans=False):
+ """
+ this is a copy of nnUNetTrainerV2's initialize. We only add the regions to the data augmentation
+ :param training:
+ :param force_load_plans:
+ :return:
+ """
+ if not self.was_initialized:
+ maybe_mkdir_p(self.output_folder)
+
+ if force_load_plans or (self.plans is None):
+ self.load_plans_file()
+
+ self.process_plans(self.plans)
+
+ self.setup_DA_params()
+
+ ################# Here we wrap the loss for deep supervision ############
+ # we need to know the number of outputs of the network
+ net_numpool = len(self.net_num_pool_op_kernel_sizes)
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
+
+            # we don't use the lowest resolution output. Normalize weights so that they sum to 1
+ mask = np.array([True if i < net_numpool - 1 else False for i in range(net_numpool)])
+ weights[~mask] = 0
+ weights = weights / weights.sum()
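+            # e.g. for net_numpool = 5: [1, 1/2, 1/4, 1/8, 0] -> [8/15, 4/15, 2/15, 1/15, 0]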
+ self.ds_loss_weights = weights
+ # now wrap the loss
+ self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
+ ################# END ###################
+
+ self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
+ "_stage%d" % self.stage)
+ if training:
+ self.dl_tr, self.dl_val = self.get_basic_generators()
+ if self.unpack_data:
+ print("unpacking dataset")
+ unpack_dataset(self.folder_with_preprocessed_data)
+ print("done")
+ else:
+ print(
+ "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
+ "will wait all winter for your model to finish!")
+
+ self.tr_gen, self.val_gen = get_moreDA_augmentation(self.dl_tr, self.dl_val,
+ self.data_aug_params[
+ 'patch_size_for_spatialtransform'],
+ self.data_aug_params,
+ deep_supervision_scales=self.deep_supervision_scales,
+ regions=self.regions)
+ self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
+ also_print_to_console=False)
+ self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
+ also_print_to_console=False)
+ else:
+ pass
+
+ self.initialize_network()
+ self.initialize_optimizer_and_scheduler()
+
+ assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
+ else:
+ self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
+ self.was_initialized = True
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
+                 step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
+ super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window, step_size=step_size,
+ save_softmax=save_softmax, use_gaussian=use_gaussian,
+ overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug,
+ all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs,
+ run_postprocessing_on_folds=run_postprocessing_on_folds)
+ # run brats specific validation
+ output_folder = join(self.output_folder, validation_folder_name)
+ evaluate_regions(output_folder, self.gt_niftis_folder, self.regions)
+
+ def run_online_evaluation(self, output, target):
+ output = output[0]
+ target = target[0]
+ with torch.no_grad():
+ out_sigmoid = torch.sigmoid(output)
+ out_sigmoid = (out_sigmoid > 0.5).float()
+
+ if self.threeD:
+ axes = (0, 2, 3, 4)
+ else:
+ axes = (0, 2, 3)
+
+ tp, fp, fn, _ = get_tp_fp_fn_tn(out_sigmoid, target, axes=axes)
+
+ tp_hard = tp.detach().cpu().numpy()
+ fp_hard = fp.detach().cpu().numpy()
+ fn_hard = fn.detach().cpu().numpy()
+
+ self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
+ self.online_eval_tp.append(list(tp_hard))
+ self.online_eval_fp.append(list(fp_hard))
+ self.online_eval_fn.append(list(fn_hard))
+
+
+class nnUNetTrainerV2BraTSRegions_Dice(nnUNetTrainerV2BraTSRegions):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.loss = SoftDiceLoss(apply_nonlin=torch.sigmoid, **{'batch_dice': False, 'do_bg': True, 'smooth': 0})
+
+
+class nnUNetTrainerV2BraTSRegions_DDP(nnUNetTrainerV2_DDP):
+ def __init__(self, plans_file, fold, local_rank, output_folder=None, dataset_directory=None, batch_dice=True,
+ stage=None,
+ unpack_data=True, deterministic=True, distribute_batch_size=False, fp16=False):
+ super().__init__(plans_file, fold, local_rank, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, distribute_batch_size, fp16)
+ self.regions = get_brats_regions()
+ self.regions_class_order = (1, 2, 3)
+ self.loss = None
+ self.ce_loss = nn.BCEWithLogitsLoss()
+
+    def process_plans(self, plans):
+        """
+        The network has as many outputs as we have regions
+        """
+        super().process_plans(plans)
+        self.num_classes = len(self.regions)
+
+ def initialize_network(self):
+ """inference_apply_nonlin to sigmoid"""
+ super().initialize_network()
+ self.network.inference_apply_nonlin = nn.Sigmoid()
+
+ def initialize(self, training=True, force_load_plans=False):
+ """
+ this is a copy of nnUNetTrainerV2's initialize. We only add the regions to the data augmentation
+ :param training:
+ :param force_load_plans:
+ :return:
+ """
+ if not self.was_initialized:
+ maybe_mkdir_p(self.output_folder)
+
+ if force_load_plans or (self.plans is None):
+ self.load_plans_file()
+
+ self.process_plans(self.plans)
+
+ self.setup_DA_params()
+
+ self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
+ "_stage%d" % self.stage)
+ if training:
+ self.dl_tr, self.dl_val = self.get_basic_generators()
+ if self.unpack_data:
+ if self.local_rank == 0:
+ print("unpacking dataset")
+ unpack_dataset(self.folder_with_preprocessed_data)
+ print("done")
+ else:
+ # we need to wait until worker 0 has finished unpacking
+ npz_files = subfiles(self.folder_with_preprocessed_data, suffix=".npz", join=False)
+ case_ids = [i[:-4] for i in npz_files]
+ all_present = all(
+ [isfile(join(self.folder_with_preprocessed_data, i + ".npy")) for i in case_ids])
+ while not all_present:
+ print("worker", self.local_rank, "is waiting for unpacking")
+ sleep(3)
+ all_present = all(
+ [isfile(join(self.folder_with_preprocessed_data, i + ".npy")) for i in case_ids])
+                        # there is some slight chance that an error may arise because a dataloader is reading a file
+                        # that is still being written by worker 0. We ignore this for now and address it only if it
+                        # becomes relevant
+                        # (this can occur because, while worker 0 is still writing, the file is technically already
+                        # present, so the other workers will proceed and eventually try to read it)
+ else:
+ print(
+ "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
+ "will wait all winter for your model to finish!")
+
+ # setting weights for deep supervision losses
+ net_numpool = len(self.net_num_pool_op_kernel_sizes)
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
+
+            # we don't use the lowest resolution output. Normalize weights so that they sum to 1
+ mask = np.array([True if i < net_numpool - 1 else False for i in range(net_numpool)])
+ weights[~mask] = 0
+ weights = weights / weights.sum()
+ self.ds_loss_weights = weights
+
+            # np.random.random_integers is deprecated; randint with an exclusive upper bound is equivalent
+            seeds_train = np.random.randint(0, 100000, self.data_aug_params.get('num_threads'))
+            seeds_val = np.random.randint(0, 100000, max(self.data_aug_params.get('num_threads') // 2, 1))
+ print("seeds train", seeds_train)
+ print("seeds_val", seeds_val)
+ self.tr_gen, self.val_gen = get_moreDA_augmentation(self.dl_tr, self.dl_val,
+ self.data_aug_params[
+ 'patch_size_for_spatialtransform'],
+ self.data_aug_params,
+ deep_supervision_scales=self.deep_supervision_scales,
+ seeds_train=seeds_train,
+ seeds_val=seeds_val,
+ pin_memory=self.pin_memory,
+ regions=self.regions)
+ self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
+ also_print_to_console=False)
+ self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
+ also_print_to_console=False)
+ else:
+ pass
+
+ self.initialize_network()
+ self.initialize_optimizer_and_scheduler()
+ self._maybe_init_amp()
+ self.network = DDP(self.network, self.local_rank)
+
+ else:
+ self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
+ self.was_initialized = True
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
+                 step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
+ super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window, step_size=step_size,
+ save_softmax=save_softmax, use_gaussian=use_gaussian,
+ overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug,
+ all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs,
+ run_postprocessing_on_folds=run_postprocessing_on_folds)
+ # run brats specific validation
+ output_folder = join(self.output_folder, validation_folder_name)
+ evaluate_regions(output_folder, self.gt_niftis_folder, self.regions)
+
+ def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
+ raise NotImplementedError("this class has not been changed to work with pytorch amp yet!")
+ data_dict = next(data_generator)
+ data = data_dict['data']
+ target = data_dict['target']
+
+ data = maybe_to_torch(data)
+ target = maybe_to_torch(target)
+
+ if torch.cuda.is_available():
+ data = to_cuda(data, gpu_id=None)
+ target = to_cuda(target, gpu_id=None)
+
+ self.optimizer.zero_grad()
+
+ output = self.network(data)
+ del data
+
+ total_loss = None
+
+ for i in range(len(output)):
+ # Starting here it gets spicy!
+ axes = tuple(range(2, len(output[i].size())))
+
+            # the network outputs raw logits. We need sigmoid probabilities for the dice computation
+            output_sigmoid = torch.sigmoid(output[i])
+
+            # get the tp, fp and fn terms we need
+            tp, fp, fn, _ = get_tp_fp_fn_tn(output_sigmoid, target[i], axes, mask=None)
+ # for dice, compute nominator and denominator so that we have to accumulate only 2 instead of 3 variables
+ # do_bg=False in nnUNetTrainer -> [:, 1:]
+ nominator = 2 * tp[:, 1:]
+ denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]
+
+ if self.batch_dice:
+ # for DDP we need to gather all nominator and denominator terms from all GPUS to do proper batch dice
+ nominator = awesome_allgather_function.apply(nominator)
+ denominator = awesome_allgather_function.apply(denominator)
+ nominator = nominator.sum(0)
+ denominator = denominator.sum(0)
+ else:
+ pass
+
+ ce_loss = self.ce_loss(output[i], target[i])
+
+ # we smooth by 1e-5 to penalize false positives if tp is 0
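+            # soft dice per output: (2*tp + eps) / (2*tp + fp + fn + eps); we minimize its negative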
+ dice_loss = (- (nominator + 1e-5) / (denominator + 1e-5)).mean()
+ if total_loss is None:
+ total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
+ else:
+ total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)
+
+ if run_online_evaluation:
+ with torch.no_grad():
+ output = output[0]
+ target = target[0]
+ out_sigmoid = torch.sigmoid(output)
+ out_sigmoid = (out_sigmoid > 0.5).float()
+
+ if self.threeD:
+ axes = (2, 3, 4)
+ else:
+ axes = (2, 3)
+
+ tp, fp, fn, _ = get_tp_fp_fn_tn(out_sigmoid, target, axes=axes)
+
+ tp_hard = awesome_allgather_function.apply(tp)
+ fp_hard = awesome_allgather_function.apply(fp)
+ fn_hard = awesome_allgather_function.apply(fn)
+ # print_if_rank0("after allgather", tp_hard.shape)
+
+ # print_if_rank0("after sum", tp_hard.shape)
+
+ self.run_online_evaluation(tp_hard.detach().cpu().numpy().sum(0),
+ fp_hard.detach().cpu().numpy().sum(0),
+ fn_hard.detach().cpu().numpy().sum(0))
+ del target
+
+ if do_backprop:
+ if not self.fp16 or amp is None or not torch.cuda.is_available():
+ total_loss.backward()
+ else:
+ with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
+ scaled_loss.backward()
+ _ = clip_grad_norm_(self.network.parameters(), 12)
+ self.optimizer.step()
+
+ return total_loss.detach().cpu().numpy()
+
+ def run_online_evaluation(self, tp, fp, fn):
+ self.online_eval_foreground_dc.append(list((2 * tp) / (2 * tp + fp + fn + 1e-8)))
+ self.online_eval_tp.append(list(tp))
+ self.online_eval_fp.append(list(fp))
+ self.online_eval_fn.append(list(fn))
+
+
diff --git a/nnunet/training/network_training/competitions_with_custom_Trainers/BraTS2020/nnUNetTrainerV2BraTSRegions_moreDA.py b/nnunet/training/network_training/competitions_with_custom_Trainers/BraTS2020/nnUNetTrainerV2BraTSRegions_moreDA.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cc0de5702ed27a6eeb92ba3057eecb59d9238d2
--- /dev/null
+++ b/nnunet/training/network_training/competitions_with_custom_Trainers/BraTS2020/nnUNetTrainerV2BraTSRegions_moreDA.py
@@ -0,0 +1,271 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+import torch
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.training.data_augmentation.data_augmentation_insaneDA2 import get_insaneDA_augmentation2
+from torch import nn
+
+from nnunet.evaluation.region_based_evaluation import evaluate_regions, get_brats_regions
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \
+ default_2D_augmentation_params, get_patch_size
+from nnunet.training.dataloading.dataset_loading import unpack_dataset
+from nnunet.training.loss_functions.deep_supervision import MultipleOutputLoss2
+from nnunet.training.loss_functions.dice_loss import DC_and_BCE_loss, get_tp_fp_fn_tn
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.training.network_training.nnUNet_variants.data_augmentation.nnUNetTrainerV2_DA3 import \
+ nnUNetTrainerV2_DA3_BN
+
+
+class nnUNetTrainerV2BraTSRegions_DA3_BN(nnUNetTrainerV2_DA3_BN):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.regions = get_brats_regions()
+ self.regions_class_order = (1, 2, 3)
+ self.loss = DC_and_BCE_loss({}, {'batch_dice': False, 'do_bg': True, 'smooth': 0})
+
+    def process_plans(self, plans):
+        """
+        The network has as many outputs as we have regions
+        """
+        super().process_plans(plans)
+        self.num_classes = len(self.regions)
+
+ def initialize_network(self):
+ """inference_apply_nonlin to sigmoid"""
+ super().initialize_network()
+ self.network.inference_apply_nonlin = nn.Sigmoid()
+
+ def initialize(self, training=True, force_load_plans=False):
+ if not self.was_initialized:
+ maybe_mkdir_p(self.output_folder)
+
+ if force_load_plans or (self.plans is None):
+ self.load_plans_file()
+
+ self.process_plans(self.plans)
+
+ self.setup_DA_params()
+
+ ################# Here we wrap the loss for deep supervision ############
+ # we need to know the number of outputs of the network
+ net_numpool = len(self.net_num_pool_op_kernel_sizes)
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
+
+            # we don't use the lowest resolution output. Normalize weights so that they sum to 1
+ mask = np.array([True] + [True if i < net_numpool - 1 else False for i in range(1, net_numpool)])
+ weights[~mask] = 0
+ weights = weights / weights.sum()
+ self.ds_loss_weights = weights
+ # now wrap the loss
+ self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
+ ################# END ###################
+
+ self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
+ "_stage%d" % self.stage)
+ if training:
+ self.dl_tr, self.dl_val = self.get_basic_generators()
+ if self.unpack_data:
+ print("unpacking dataset")
+ unpack_dataset(self.folder_with_preprocessed_data)
+ print("done")
+ else:
+ print(
+ "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
+ "will wait all winter for your model to finish!")
+
+ self.tr_gen, self.val_gen = get_insaneDA_augmentation2(
+ self.dl_tr, self.dl_val,
+ self.data_aug_params[
+ 'patch_size_for_spatialtransform'],
+ self.data_aug_params,
+ deep_supervision_scales=self.deep_supervision_scales,
+ pin_memory=self.pin_memory,
+ regions=self.regions
+ )
+ self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
+ also_print_to_console=False)
+ self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
+ also_print_to_console=False)
+ else:
+ pass
+
+ self.initialize_network()
+ self.initialize_optimizer_and_scheduler()
+
+ assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
+ else:
+ self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
+ self.was_initialized = True
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
+                 step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
+ super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window, step_size=step_size,
+ save_softmax=save_softmax, use_gaussian=use_gaussian,
+ overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug,
+ all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs,
+ run_postprocessing_on_folds=run_postprocessing_on_folds)
+ # run brats specific validation
+ output_folder = join(self.output_folder, validation_folder_name)
+ evaluate_regions(output_folder, self.gt_niftis_folder, self.regions)
+
+ def run_online_evaluation(self, output, target):
+ output = output[0]
+ target = target[0]
+ with torch.no_grad():
+ out_sigmoid = torch.sigmoid(output)
+ out_sigmoid = (out_sigmoid > 0.5).float()
+
+ if self.threeD:
+ axes = (0, 2, 3, 4)
+ else:
+ axes = (0, 2, 3)
+
+ tp, fp, fn, _ = get_tp_fp_fn_tn(out_sigmoid, target, axes=axes)
+
+ tp_hard = tp.detach().cpu().numpy()
+ fp_hard = fp.detach().cpu().numpy()
+ fn_hard = fn.detach().cpu().numpy()
+
+ self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
+ self.online_eval_tp.append(list(tp_hard))
+ self.online_eval_fp.append(list(fp_hard))
+ self.online_eval_fn.append(list(fn_hard))
+
+
+class nnUNetTrainerV2BraTSRegions_DA3(nnUNetTrainerV2BraTSRegions_DA3_BN):
+ def initialize_network(self):
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.InstanceNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.InstanceNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
+ dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = nn.Sigmoid()
+
+
+class nnUNetTrainerV2BraTSRegions_DA3_BD(nnUNetTrainerV2BraTSRegions_DA3):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.loss = DC_and_BCE_loss({}, {'batch_dice': True, 'do_bg': True, 'smooth': 0})
+
+
+class nnUNetTrainerV2BraTSRegions_DA3_BN_BD(nnUNetTrainerV2BraTSRegions_DA3_BN):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.loss = DC_and_BCE_loss({}, {'batch_dice': True, 'do_bg': True, 'smooth': 0})
+
+
+class nnUNetTrainerV2BraTSRegions_DA4_BN(nnUNetTrainerV2BraTSRegions_DA3_BN):
+ def setup_DA_params(self):
+ nnUNetTrainerV2.setup_DA_params(self)
+ self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(
+ np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1]
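+        # e.g. pool kernel sizes [[2, 2, 2], [2, 2, 2], [2, 2, 2]] yield deep supervision scales
+        # [[1, 1, 1], [0.5, 0.5, 0.5], [0.25, 0.25, 0.25]] (the lowest resolution is dropped)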
+
+ if self.threeD:
+ self.data_aug_params = default_3D_augmentation_params
+ self.data_aug_params['rotation_x'] = (-90. / 360 * 2. * np.pi, 90. / 360 * 2. * np.pi)
+ self.data_aug_params['rotation_y'] = (-90. / 360 * 2. * np.pi, 90. / 360 * 2. * np.pi)
+ self.data_aug_params['rotation_z'] = (-90. / 360 * 2. * np.pi, 90. / 360 * 2. * np.pi)
+ if self.do_dummy_2D_aug:
+ self.data_aug_params["dummy_2D"] = True
+ self.print_to_log_file("Using dummy2d data augmentation")
+ self.data_aug_params["elastic_deform_alpha"] = \
+ default_2D_augmentation_params["elastic_deform_alpha"]
+ self.data_aug_params["elastic_deform_sigma"] = \
+ default_2D_augmentation_params["elastic_deform_sigma"]
+ self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
+ else:
+ self.do_dummy_2D_aug = False
+ if max(self.patch_size) / min(self.patch_size) > 1.5:
+ default_2D_augmentation_params['rotation_x'] = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
+ self.data_aug_params = default_2D_augmentation_params
+ self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm
+
+ if self.do_dummy_2D_aug:
+ self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
+ self.data_aug_params['rotation_x'],
+ self.data_aug_params['rotation_y'],
+ self.data_aug_params['rotation_z'],
+ self.data_aug_params['scale_range'])
+ self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
+ else:
+ self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
+ self.data_aug_params['rotation_y'],
+ self.data_aug_params['rotation_z'],
+ self.data_aug_params['scale_range'])
+
+ self.data_aug_params['selected_seg_channels'] = [0]
+ self.data_aug_params['patch_size_for_spatialtransform'] = self.patch_size
+
+ self.data_aug_params["p_rot"] = 0.3
+
+ self.data_aug_params["scale_range"] = (0.65, 1.6)
+ self.data_aug_params["p_scale"] = 0.3
+ self.data_aug_params["independent_scale_factor_for_each_axis"] = True
+ self.data_aug_params["p_independent_scale_per_axis"] = 0.3
+
+ self.data_aug_params["do_elastic"] = True
+ self.data_aug_params["p_eldef"] = 0.2
+ self.data_aug_params["eldef_deformation_scale"] = (0, 0.25)
+
+ self.data_aug_params["do_additive_brightness"] = True
+ self.data_aug_params["additive_brightness_mu"] = 0
+ self.data_aug_params["additive_brightness_sigma"] = 0.2
+ self.data_aug_params["additive_brightness_p_per_sample"] = 0.3
+ self.data_aug_params["additive_brightness_p_per_channel"] = 0.5
+
+ self.data_aug_params['gamma_range'] = (0.5, 1.6)
+
+ self.data_aug_params['num_cached_per_thread'] = 4
+
+
+class nnUNetTrainerV2BraTSRegions_DA4_BN_BD(nnUNetTrainerV2BraTSRegions_DA4_BN):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.loss = DC_and_BCE_loss({}, {'batch_dice': True, 'do_bg': True, 'smooth': 0})
diff --git a/nnunet/training/network_training/competitions_with_custom_Trainers/MMS/__init__.py b/nnunet/training/network_training/competitions_with_custom_Trainers/MMS/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/training/network_training/competitions_with_custom_Trainers/MMS/nnUNetTrainerV2_MMS.py b/nnunet/training/network_training/competitions_with_custom_Trainers/MMS/nnUNetTrainerV2_MMS.py
new file mode 100755
index 0000000000000000000000000000000000000000..7907fb69542584c044bb901f0348ed8fd6ad0055
--- /dev/null
+++ b/nnunet/training/network_training/competitions_with_custom_Trainers/MMS/nnUNetTrainerV2_MMS.py
@@ -0,0 +1,60 @@
+import torch
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNet_variants.data_augmentation.nnUNetTrainerV2_insaneDA import \
+ nnUNetTrainerV2_insaneDA
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+
+
+class nnUNetTrainerV2_MMS(nnUNetTrainerV2_insaneDA):
+ def setup_DA_params(self):
+ super().setup_DA_params()
+ self.data_aug_params["p_rot"] = 0.7
+ self.data_aug_params["p_eldef"] = 0.1
+ self.data_aug_params["p_scale"] = 0.3
+
+ self.data_aug_params["independent_scale_factor_for_each_axis"] = True
+ self.data_aug_params["p_independent_scale_per_axis"] = 0.3
+
+ self.data_aug_params["do_additive_brightness"] = True
+ self.data_aug_params["additive_brightness_mu"] = 0
+ self.data_aug_params["additive_brightness_sigma"] = 0.2
+ self.data_aug_params["additive_brightness_p_per_sample"] = 0.3
+ self.data_aug_params["additive_brightness_p_per_channel"] = 1
+
+ self.data_aug_params["elastic_deform_alpha"] = (0., 300.)
+ self.data_aug_params["elastic_deform_sigma"] = (9., 15.)
+
+ self.data_aug_params['gamma_range'] = (0.5, 1.6)
+
+ def initialize_network(self):
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.BatchNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.BatchNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
+ dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
+
+ """def run_training(self):
+ from batchviewer import view_batch
+ a = next(self.tr_gen)
+ view_batch(a['data'])
+ import IPython;IPython.embed()"""
diff --git a/nnunet/training/network_training/competitions_with_custom_Trainers/__init__.py b/nnunet/training/network_training/competitions_with_custom_Trainers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/training/network_training/network_trainer.py b/nnunet/training/network_training/network_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..abba3067ce782b0978c177df84b614ea9e03bc0f
--- /dev/null
+++ b/nnunet/training/network_training/network_trainer.py
@@ -0,0 +1,737 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from _warnings import warn
+from typing import Tuple
+
+import matplotlib
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from sklearn.model_selection import KFold
+from torch import nn
+from torch.cuda.amp import GradScaler, autocast
+from torch.optim.lr_scheduler import _LRScheduler
+
+matplotlib.use("agg")
+from time import time, sleep
+import torch
+import numpy as np
+from torch.optim import lr_scheduler
+import matplotlib.pyplot as plt
+import sys
+from collections import OrderedDict
+import torch.backends.cudnn as cudnn
+from abc import abstractmethod
+from datetime import datetime
+from tqdm import trange
+from nnunet.utilities.to_torch import maybe_to_torch, to_cuda
+
+
+class NetworkTrainer(object):
+ def __init__(self, deterministic=True, fp16=False):
+ """
+        A generic class that can train almost any neural network (RNNs excluded). It provides basic functionality such
+        as the training loop and tracking of training and validation losses (and the target metric, if you implement it).
+        Training can be terminated early if the validation loss (or the target metric, if implemented) does not improve
+        anymore. This is based on a moving average (MA) of the loss/metric instead of the raw values to get smoother
+        results.
+
+ What you need to override:
+ - __init__
+ - initialize
+ - run_online_evaluation (optional)
+ - finish_online_evaluation (optional)
+ - validate
+ - predict_test_case
+ """
+ self.fp16 = fp16
+ self.amp_grad_scaler = None
+
+ if deterministic:
+ np.random.seed(12345)
+ torch.manual_seed(12345)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(12345)
+ cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ else:
+ cudnn.deterministic = False
+ torch.backends.cudnn.benchmark = True
+
+ ################# SET THESE IN self.initialize() ###################################
+ self.network: Tuple[SegmentationNetwork, nn.DataParallel] = None
+ self.optimizer = None
+ self.lr_scheduler = None
+ self.tr_gen = self.val_gen = None
+ self.was_initialized = False
+
+ ################# SET THESE IN INIT ################################################
+ self.output_folder = None
+ self.fold = None
+ self.loss = None
+ self.dataset_directory = None
+
+ ################# SET THESE IN LOAD_DATASET OR DO_SPLIT ############################
+ self.dataset = None # these can be None for inference mode
+ self.dataset_tr = self.dataset_val = None # do not need to be used, they just appear if you are using the suggested load_dataset_and_do_split
+
+ ################# THESE DO NOT NECESSARILY NEED TO BE MODIFIED #####################
+ self.patience = 50
+ self.val_eval_criterion_alpha = 0.9 # alpha * old + (1-alpha) * new
+ # if this is too low then the moving average will be too noisy and the training may terminate early. If it is
+ # too high the training will take forever
+ self.train_loss_MA_alpha = 0.93 # alpha * old + (1-alpha) * new
+ self.train_loss_MA_eps = 5e-4 # new MA must be at least this much better (smaller)
+ self.max_num_epochs = 1000
+ self.num_batches_per_epoch = 250
+ self.num_val_batches_per_epoch = 50
+ self.also_val_in_tr_mode = False
+        self.lr_threshold = 1e-6  # training will not terminate early while the lr is still above this threshold
+
+ ################# LEAVE THESE ALONE ################################################
+ self.val_eval_criterion_MA = None
+ self.train_loss_MA = None
+ self.best_val_eval_criterion_MA = None
+ self.best_MA_tr_loss_for_patience = None
+ self.best_epoch_based_on_MA_tr_loss = None
+ self.all_tr_losses = []
+ self.all_val_losses = []
+ self.all_val_losses_tr_mode = []
+ self.all_val_eval_metrics = [] # does not have to be used
+ self.epoch = 0
+ self.log_file = None
+ self.deterministic = deterministic
+
+ self.use_progress_bar = False
+ if 'nnunet_use_progress_bar' in os.environ.keys():
+ self.use_progress_bar = bool(int(os.environ['nnunet_use_progress_bar']))
+
+ ################# Settings for saving checkpoints ##################################
+ self.save_every = 50
+        self.save_latest_only = True  # if False, each intermediate checkpoint is stored as a separate file
+        # instead of overwriting model_latest
+ self.save_intermediate_checkpoints = True # whether or not to save checkpoint_latest
+ self.save_best_checkpoint = True # whether or not to save the best checkpoint according to self.best_val_eval_criterion_MA
+ self.save_final_checkpoint = True # whether or not to save the final checkpoint
+
+ @abstractmethod
+ def initialize(self, training=True):
+ """
+ create self.output_folder
+
+ modify self.output_folder if you are doing cross-validation (one folder per fold)
+
+ set self.tr_gen and self.val_gen
+
+ call self.initialize_network and self.initialize_optimizer_and_scheduler (important!)
+
+ finally set self.was_initialized to True
+ :param training:
+ :return:
+ """
+
+ @abstractmethod
+ def load_dataset(self):
+ pass
+
+ def do_split(self):
+ """
+ This is a suggestion for if your dataset is a dictionary (my personal standard)
+ :return:
+ """
+ splits_file = join(self.dataset_directory, "splits_final.pkl")
+ if not isfile(splits_file):
+ self.print_to_log_file("Creating new split...")
+ splits = []
+ all_keys_sorted = np.sort(list(self.dataset.keys()))
+ kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
+ for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
+ train_keys = np.array(all_keys_sorted)[train_idx]
+ test_keys = np.array(all_keys_sorted)[test_idx]
+ splits.append(OrderedDict())
+ splits[-1]['train'] = train_keys
+ splits[-1]['val'] = test_keys
+ save_pickle(splits, splits_file)
+
+ splits = load_pickle(splits_file)
+
+ if self.fold == "all":
+ tr_keys = val_keys = list(self.dataset.keys())
+ else:
+ tr_keys = splits[self.fold]['train']
+ val_keys = splits[self.fold]['val']
+
+ tr_keys.sort()
+ val_keys.sort()
+
+ self.dataset_tr = OrderedDict()
+ for i in tr_keys:
+ self.dataset_tr[i] = self.dataset[i]
+
+ self.dataset_val = OrderedDict()
+ for i in val_keys:
+ self.dataset_val[i] = self.dataset[i]
+
+ def plot_progress(self):
+ """
+        Should probably be improved
+ :return:
+ """
+ try:
+ font = {'weight': 'normal',
+ 'size': 18}
+
+ matplotlib.rc('font', **font)
+
+ fig = plt.figure(figsize=(30, 24))
+ ax = fig.add_subplot(111)
+ ax2 = ax.twinx()
+
+ x_values = list(range(self.epoch + 1))
+
+ ax.plot(x_values, self.all_tr_losses, color='b', ls='-', label="loss_tr")
+
+ ax.plot(x_values, self.all_val_losses, color='r', ls='-', label="loss_val, train=False")
+
+ if len(self.all_val_losses_tr_mode) > 0:
+ ax.plot(x_values, self.all_val_losses_tr_mode, color='g', ls='-', label="loss_val, train=True")
+ if len(self.all_val_eval_metrics) == len(x_values):
+ ax2.plot(x_values, self.all_val_eval_metrics, color='g', ls='--', label="evaluation metric")
+
+ ax.set_xlabel("epoch")
+ ax.set_ylabel("loss")
+ ax2.set_ylabel("evaluation metric")
+ ax.legend()
+ ax2.legend(loc=9)
+
+ fig.savefig(join(self.output_folder, "progress.png"))
+ plt.close()
+ except IOError:
+ self.print_to_log_file("failed to plot: ", sys.exc_info())
+
+ def print_to_log_file(self, *args, also_print_to_console=True, add_timestamp=True):
+
+ timestamp = time()
+ dt_object = datetime.fromtimestamp(timestamp)
+
+ if add_timestamp:
+ args = ("%s:" % dt_object, *args)
+
+ if self.log_file is None:
+ maybe_mkdir_p(self.output_folder)
+ timestamp = datetime.now()
+ self.log_file = join(self.output_folder, "training_log_%d_%d_%d_%02.0d_%02.0d_%02.0d.txt" %
+ (timestamp.year, timestamp.month, timestamp.day, timestamp.hour, timestamp.minute,
+ timestamp.second))
+ with open(self.log_file, 'w') as f:
+ f.write("Starting... \n")
+ successful = False
+ max_attempts = 5
+ ctr = 0
+ while not successful and ctr < max_attempts:
+ try:
+ with open(self.log_file, 'a+') as f:
+ for a in args:
+ f.write(str(a))
+ f.write(" ")
+ f.write("\n")
+ successful = True
+ except IOError:
+ print("%s: failed to log: " % datetime.fromtimestamp(timestamp), sys.exc_info())
+ sleep(0.5)
+ ctr += 1
+ if also_print_to_console:
+ print(*args)
+
+ def save_checkpoint(self, fname, save_optimizer=True):
+ start_time = time()
+ state_dict = self.network.state_dict()
+ for key in state_dict.keys():
+ state_dict[key] = state_dict[key].cpu()
+ lr_sched_state_dct = None
+ if self.lr_scheduler is not None and hasattr(self.lr_scheduler,
+ 'state_dict'): # not isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
+ lr_sched_state_dct = self.lr_scheduler.state_dict()
+ if save_optimizer:
+ optimizer_state_dict = self.optimizer.state_dict()
+ else:
+ optimizer_state_dict = None
+
+ self.print_to_log_file("saving checkpoint...")
+ save_this = {
+ 'epoch': self.epoch + 1,
+ 'state_dict': state_dict,
+ 'optimizer_state_dict': optimizer_state_dict,
+ 'lr_scheduler_state_dict': lr_sched_state_dct,
+ 'plot_stuff': (self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode,
+ self.all_val_eval_metrics),
+ 'best_stuff' : (self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA)}
+ if self.amp_grad_scaler is not None:
+ save_this['amp_grad_scaler'] = self.amp_grad_scaler.state_dict()
+
+ torch.save(save_this, fname)
+ self.print_to_log_file("done, saving took %.2f seconds" % (time() - start_time))
+
+ def load_best_checkpoint(self, train=True):
+ if self.fold is None:
+ raise RuntimeError("Cannot load best checkpoint if self.fold is None")
+ if isfile(join(self.output_folder, "model_best.model")):
+ self.load_checkpoint(join(self.output_folder, "model_best.model"), train=train)
+ else:
+ self.print_to_log_file("WARNING! model_best.model does not exist! Cannot load best checkpoint. Falling "
+ "back to load_latest_checkpoint")
+ self.load_latest_checkpoint(train)
+
+ def load_latest_checkpoint(self, train=True):
+ if isfile(join(self.output_folder, "model_final_checkpoint.model")):
+ return self.load_checkpoint(join(self.output_folder, "model_final_checkpoint.model"), train=train)
+ if isfile(join(self.output_folder, "model_latest.model")):
+ return self.load_checkpoint(join(self.output_folder, "model_latest.model"), train=train)
+ if isfile(join(self.output_folder, "model_best.model")):
+ return self.load_best_checkpoint(train)
+ raise RuntimeError("No checkpoint found")
+
+ def load_final_checkpoint(self, train=False):
+ filename = join(self.output_folder, "model_final_checkpoint.model")
+ if not isfile(filename):
+ raise RuntimeError("Final checkpoint not found. Expected: %s. Please finish the training first." % filename)
+ return self.load_checkpoint(filename, train=train)
+
+ def load_checkpoint(self, fname, train=True):
+ self.print_to_log_file("loading checkpoint", fname, "train=", train)
+ if not self.was_initialized:
+ self.initialize(train)
+ # saved_model = torch.load(fname, map_location=torch.device('cuda', torch.cuda.current_device()))
+ saved_model = torch.load(fname, map_location=torch.device('cpu'))
+ self.load_checkpoint_ram(saved_model, train)
+
+ @abstractmethod
+ def initialize_network(self):
+ """
+ initialize self.network here
+ :return:
+ """
+ pass
+
+ @abstractmethod
+ def initialize_optimizer_and_scheduler(self):
+ """
+ initialize self.optimizer and self.lr_scheduler (if applicable) here
+ :return:
+ """
+ pass
+
+ def load_checkpoint_ram(self, checkpoint, train=True):
+ """
+        used when the checkpoint is already in ram
+ :param checkpoint:
+ :param train:
+ :return:
+ """
+ if not self.was_initialized:
+ self.initialize(train)
+
+ new_state_dict = OrderedDict()
+ curr_state_dict_keys = list(self.network.state_dict().keys())
+ # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not
+ # match. Use heuristic to make it match
+ for k, value in checkpoint['state_dict'].items():
+ key = k
+ if key not in curr_state_dict_keys and key.startswith('module.'):
+ key = key[7:]
+ new_state_dict[key] = value
+
+ if self.fp16:
+ self._maybe_init_amp()
+ if train:
+ if 'amp_grad_scaler' in checkpoint.keys():
+ self.amp_grad_scaler.load_state_dict(checkpoint['amp_grad_scaler'])
+
+ self.network.load_state_dict(new_state_dict)
+ self.epoch = checkpoint['epoch']
+ if train:
+ optimizer_state_dict = checkpoint['optimizer_state_dict']
+ if optimizer_state_dict is not None:
+ self.optimizer.load_state_dict(optimizer_state_dict)
+
+ if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'load_state_dict') and checkpoint[
+ 'lr_scheduler_state_dict'] is not None:
+ self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
+
+ if issubclass(self.lr_scheduler.__class__, _LRScheduler):
+ self.lr_scheduler.step(self.epoch)
+
+ self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics = checkpoint[
+ 'plot_stuff']
+
+ # load best loss (if present)
+ if 'best_stuff' in checkpoint.keys():
+ self.best_epoch_based_on_MA_tr_loss, self.best_MA_tr_loss_for_patience, self.best_val_eval_criterion_MA = checkpoint[
+ 'best_stuff']
+
+ # after the training is done, the epoch is incremented one more time in my old code. This results in
+ # self.epoch = 1001 for old trained models when the epoch is actually 1000. This causes issues because
+ # len(self.all_tr_losses) = 1000 and the plot function will fail. We can easily detect and correct that here
+ if self.epoch != len(self.all_tr_losses):
+ self.print_to_log_file("WARNING in loading checkpoint: self.epoch != len(self.all_tr_losses). This is "
+ "due to an old bug and should only appear when you are loading old models. New "
+ "models should have this fixed! self.epoch is now set to len(self.all_tr_losses)")
+ self.epoch = len(self.all_tr_losses)
+ self.all_tr_losses = self.all_tr_losses[:self.epoch]
+ self.all_val_losses = self.all_val_losses[:self.epoch]
+ self.all_val_losses_tr_mode = self.all_val_losses_tr_mode[:self.epoch]
+ self.all_val_eval_metrics = self.all_val_eval_metrics[:self.epoch]
+
+ self._maybe_init_amp()
+
+ def _maybe_init_amp(self):
+ if self.fp16 and self.amp_grad_scaler is None:
+ self.amp_grad_scaler = GradScaler()
+
+ def plot_network_architecture(self):
+ """
+ can be implemented (see nnUNetTrainer) but does not have to. Not implemented here because it imposes stronger
+ assumptions on the presence of class variables
+ :return:
+ """
+ pass
+
+ def run_training(self):
+ if not torch.cuda.is_available():
+ self.print_to_log_file("WARNING!!! You are attempting to run training on a CPU (torch.cuda.is_available() is False). This can be VERY slow!")
+
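+        # fetch one batch from each generator up front; with batchgenerators' multithreaded augmenters
+        # this is what spawns the background workers before the first epoch is timed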
+ _ = self.tr_gen.next()
+ _ = self.val_gen.next()
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ self._maybe_init_amp()
+
+ maybe_mkdir_p(self.output_folder)
+ self.plot_network_architecture()
+
+ if cudnn.benchmark and cudnn.deterministic:
+ warn("torch.backends.cudnn.deterministic is True indicating a deterministic training is desired. "
+ "But torch.backends.cudnn.benchmark is True as well and this will prevent deterministic training! "
+ "If you want deterministic then set benchmark=False")
+
+ if not self.was_initialized:
+ self.initialize(True)
+
+ while self.epoch < self.max_num_epochs:
+ self.print_to_log_file("\nepoch: ", self.epoch)
+ epoch_start_time = time()
+ train_losses_epoch = []
+
+ # train one epoch
+ self.network.train()
+
+ if self.use_progress_bar:
+ with trange(self.num_batches_per_epoch) as tbar:
+ for b in tbar:
+ tbar.set_description("Epoch {}/{}".format(self.epoch+1, self.max_num_epochs))
+
+ l = self.run_iteration(self.tr_gen, True)
+
+ tbar.set_postfix(loss=l)
+ train_losses_epoch.append(l)
+ else:
+ for _ in range(self.num_batches_per_epoch):
+ l = self.run_iteration(self.tr_gen, True)
+ train_losses_epoch.append(l)
+
+ self.all_tr_losses.append(np.mean(train_losses_epoch))
+ self.print_to_log_file("train loss : %.4f" % self.all_tr_losses[-1])
+
+ with torch.no_grad():
+ # validation with train=False
+ self.network.eval()
+ val_losses = []
+ for b in range(self.num_val_batches_per_epoch):
+ l = self.run_iteration(self.val_gen, False, True)
+ val_losses.append(l)
+ self.all_val_losses.append(np.mean(val_losses))
+ self.print_to_log_file("validation loss: %.4f" % self.all_val_losses[-1])
+
+ if self.also_val_in_tr_mode:
+ self.network.train()
+ # validation with train=True
+ val_losses = []
+ for b in range(self.num_val_batches_per_epoch):
+ l = self.run_iteration(self.val_gen, False)
+ val_losses.append(l)
+ self.all_val_losses_tr_mode.append(np.mean(val_losses))
+ self.print_to_log_file("validation loss (train=True): %.4f" % self.all_val_losses_tr_mode[-1])
+
+ self.update_train_loss_MA() # needed for lr scheduler and stopping of training
+
+ continue_training = self.on_epoch_end()
+
+ epoch_end_time = time()
+
+ if not continue_training:
+ # allows for early stopping
+ break
+
+ self.epoch += 1
+ self.print_to_log_file("This epoch took %f s\n" % (epoch_end_time - epoch_start_time))
+
+ self.epoch -= 1 # if we don't do this we can get a problem with loading model_final_checkpoint.
+
+ if self.save_final_checkpoint: self.save_checkpoint(join(self.output_folder, "model_final_checkpoint.model"))
+ # now we can delete latest as it will be identical with final
+ if isfile(join(self.output_folder, "model_latest.model")):
+ os.remove(join(self.output_folder, "model_latest.model"))
+ if isfile(join(self.output_folder, "model_latest.model.pkl")):
+ os.remove(join(self.output_folder, "model_latest.model.pkl"))
+
+ def maybe_update_lr(self):
+ # maybe update learning rate
+ if self.lr_scheduler is not None:
+ assert isinstance(self.lr_scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))
+
+ if isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
+ # lr scheduler is updated with moving average val loss. should be more robust
+ self.lr_scheduler.step(self.train_loss_MA)
+ else:
+ self.lr_scheduler.step(self.epoch + 1)
+ self.print_to_log_file("lr is now (scheduler) %s" % str(self.optimizer.param_groups[0]['lr']))
+
+ def maybe_save_checkpoint(self):
+ """
+        Saves a checkpoint every save_every epochs.
+ :return:
+ """
+ if self.save_intermediate_checkpoints and (self.epoch % self.save_every == (self.save_every - 1)):
+ self.print_to_log_file("saving scheduled checkpoint file...")
+ if not self.save_latest_only:
+ self.save_checkpoint(join(self.output_folder, "model_ep_%03.0d.model" % (self.epoch + 1)))
+ self.save_checkpoint(join(self.output_folder, "model_latest.model"))
+ self.print_to_log_file("done")
+
+ def update_eval_criterion_MA(self):
+ """
+        If self.all_val_eval_metrics is unused (len=0) then we fall back to using -self.all_val_losses for the MA
+        that determines early stopping (the criterion is maximized, not minimized, hence the minus sign that turns
+        the loss, where lower is better, into a score where higher is better)
+ :return:
+ """
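+        # illustrative numbers for the moving average update in the else branch below: with
+        # val_eval_criterion_alpha = 0.9, a previous MA of 0.80 and a new metric value of 0.90,
+        # the update gives 0.9 * 0.80 + 0.1 * 0.90 = 0.81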
+ if self.val_eval_criterion_MA is None:
+ if len(self.all_val_eval_metrics) == 0:
+ self.val_eval_criterion_MA = - self.all_val_losses[-1]
+ else:
+ self.val_eval_criterion_MA = self.all_val_eval_metrics[-1]
+ else:
+ if len(self.all_val_eval_metrics) == 0:
+ """
+                We use alpha * old - (1 - alpha) * new here because new in this case is the validation loss and lower
+ is better, so we need to negate it.
+ """
+ self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA - (
+ 1 - self.val_eval_criterion_alpha) * \
+ self.all_val_losses[-1]
+ else:
+ self.val_eval_criterion_MA = self.val_eval_criterion_alpha * self.val_eval_criterion_MA + (
+ 1 - self.val_eval_criterion_alpha) * \
+ self.all_val_eval_metrics[-1]
+
+ def manage_patience(self):
+ # update patience
+ continue_training = True
+ if self.patience is not None:
+ # if best_MA_tr_loss_for_patience and best_epoch_based_on_MA_tr_loss were not yet initialized,
+ # initialize them
+ if self.best_MA_tr_loss_for_patience is None:
+ self.best_MA_tr_loss_for_patience = self.train_loss_MA
+
+ if self.best_epoch_based_on_MA_tr_loss is None:
+ self.best_epoch_based_on_MA_tr_loss = self.epoch
+
+ if self.best_val_eval_criterion_MA is None:
+ self.best_val_eval_criterion_MA = self.val_eval_criterion_MA
+
+ # check if the current epoch is the best one according to moving average of validation criterion. If so
+ # then save 'best' model
+ # Do not use this for validation. This is intended for test set prediction only.
+ #self.print_to_log_file("current best_val_eval_criterion_MA is %.4f0" % self.best_val_eval_criterion_MA)
+ #self.print_to_log_file("current val_eval_criterion_MA is %.4f" % self.val_eval_criterion_MA)
+
+ if self.val_eval_criterion_MA > self.best_val_eval_criterion_MA:
+ self.best_val_eval_criterion_MA = self.val_eval_criterion_MA
+ #self.print_to_log_file("saving best epoch checkpoint...")
+ if self.save_best_checkpoint: self.save_checkpoint(join(self.output_folder, "model_best.model"))
+
+ # Now see if the moving average of the train loss has improved. If yes then reset patience, else
+ # increase patience
+ if self.train_loss_MA + self.train_loss_MA_eps < self.best_MA_tr_loss_for_patience:
+ self.best_MA_tr_loss_for_patience = self.train_loss_MA
+ self.best_epoch_based_on_MA_tr_loss = self.epoch
+ #self.print_to_log_file("New best epoch (train loss MA): %03.4f" % self.best_MA_tr_loss_for_patience)
+ else:
+ pass
+ #self.print_to_log_file("No improvement: current train MA %03.4f, best: %03.4f, eps is %03.4f" %
+ # (self.train_loss_MA, self.best_MA_tr_loss_for_patience, self.train_loss_MA_eps))
+
+ # if patience has reached its maximum then finish training (provided lr is low enough)
+ if self.epoch - self.best_epoch_based_on_MA_tr_loss > self.patience:
+ if self.optimizer.param_groups[0]['lr'] > self.lr_threshold:
+ #self.print_to_log_file("My patience ended, but I believe I need more time (lr > 1e-6)")
+ self.best_epoch_based_on_MA_tr_loss = self.epoch - self.patience // 2
+ else:
+ #self.print_to_log_file("My patience ended")
+ continue_training = False
+ else:
+ pass
+ #self.print_to_log_file(
+ # "Patience: %d/%d" % (self.epoch - self.best_epoch_based_on_MA_tr_loss, self.patience))
+
+ return continue_training
+
+ def on_epoch_end(self):
+        # finish_online_evaluation does not have to do anything, but can be used to update self.all_val_eval_metrics
+        self.finish_online_evaluation()
+
+ self.plot_progress()
+
+ self.maybe_update_lr()
+
+ self.maybe_save_checkpoint()
+
+ self.update_eval_criterion_MA()
+
+ continue_training = self.manage_patience()
+ return continue_training
+
+ def update_train_loss_MA(self):
+ if self.train_loss_MA is None:
+ self.train_loss_MA = self.all_tr_losses[-1]
+ else:
+ self.train_loss_MA = self.train_loss_MA_alpha * self.train_loss_MA + (1 - self.train_loss_MA_alpha) * \
+ self.all_tr_losses[-1]
+
+ def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
+ data_dict = next(data_generator)
+ data = data_dict['data']
+ target = data_dict['target']
+
+ data = maybe_to_torch(data)
+ target = maybe_to_torch(target)
+
+ if torch.cuda.is_available():
+ data = to_cuda(data)
+ target = to_cuda(target)
+
+ self.optimizer.zero_grad()
+
+ if self.fp16:
+ with autocast():
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ if do_backprop:
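+                # standard torch.cuda.amp pattern: scale the loss before backward, let the scaler
+                # unscale the gradients and step the optimizer, then update the scale factor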
+ self.amp_grad_scaler.scale(l).backward()
+ self.amp_grad_scaler.step(self.optimizer)
+ self.amp_grad_scaler.update()
+ else:
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ if do_backprop:
+ l.backward()
+ self.optimizer.step()
+
+ if run_online_evaluation:
+ self.run_online_evaluation(output, target)
+
+ del target
+
+ return l.detach().cpu().numpy()
+
+ def run_online_evaluation(self, *args, **kwargs):
+ """
+ Can be implemented, does not have to
+        :param args: typically (network output, target); see nnUNetTrainer for a concrete implementation
+        :param kwargs:
+ :return:
+ """
+ pass
+
+ def finish_online_evaluation(self):
+ """
+ Can be implemented, does not have to
+ :return:
+ """
+ pass
+
+ @abstractmethod
+ def validate(self, *args, **kwargs):
+ pass
+
+ def find_lr(self, num_iters=1000, init_value=1e-6, final_value=10., beta=0.98):
+ """
+ stolen and adapted from here: https://sgugger.github.io/how-do-you-find-a-good-learning-rate.html
+ :param num_iters:
+ :param init_value:
+ :param final_value:
+ :param beta:
+ :return:
+ """
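+        # usage sketch (illustrative, not part of the training pipeline): after initialization call
+        #   lrs, losses = self.find_lr(num_iters=500)
+        # then inspect lr_finder.png in self.output_folder and pick a learning rate slightly below
+        # the point where the smoothed loss decreases fastest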
+ import math
+ self._maybe_init_amp()
+ mult = (final_value / init_value) ** (1 / num_iters)
+ lr = init_value
+ self.optimizer.param_groups[0]['lr'] = lr
+ avg_loss = 0.
+ best_loss = 0.
+ losses = []
+ log_lrs = []
+
+ for batch_num in range(1, num_iters + 1):
+ # +1 because this one here is not designed to have negative loss...
+ loss = self.run_iteration(self.tr_gen, do_backprop=True, run_online_evaluation=False).data.item() + 1
+
+ # Compute the smoothed loss
+ avg_loss = beta * avg_loss + (1 - beta) * loss
+ smoothed_loss = avg_loss / (1 - beta ** batch_num)
+
+ # Stop if the loss is exploding
+ if batch_num > 1 and smoothed_loss > 4 * best_loss:
+ break
+
+ # Record the best loss
+ if smoothed_loss < best_loss or batch_num == 1:
+ best_loss = smoothed_loss
+
+ # Store the values
+ losses.append(smoothed_loss)
+ log_lrs.append(math.log10(lr))
+
+ # Update the lr for the next step
+ lr *= mult
+ self.optimizer.param_groups[0]['lr'] = lr
+
+ import matplotlib.pyplot as plt
+ lrs = [10 ** i for i in log_lrs]
+ fig = plt.figure()
+ plt.xscale('log')
+ plt.plot(lrs[10:-5], losses[10:-5])
+ plt.savefig(join(self.output_folder, "lr_finder.png"))
+ plt.close()
+ return log_lrs, losses
diff --git a/nnunet/training/network_training/nnUNetTrainer.py b/nnunet/training/network_training/nnUNetTrainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..aba016cb29b92a79f9467cfcaae2650837e15581
--- /dev/null
+++ b/nnunet/training/network_training/nnUNetTrainer.py
@@ -0,0 +1,734 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import shutil
+from collections import OrderedDict
+from multiprocessing import Pool
+from time import sleep
+from typing import Tuple, List
+
+import matplotlib
+import numpy as np
+import torch
+from batchgenerators.utilities.file_and_folder_operations import *
+from torch import nn
+from torch.optim import lr_scheduler
+
+import nnunet
+from nnunet.configuration import default_num_threads
+from nnunet.evaluation.evaluator import aggregate_scores
+from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.postprocessing.connected_components import determine_postprocessing
+from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \
+ default_2D_augmentation_params, get_default_augmentation, get_patch_size
+from nnunet.training.dataloading.dataset_loading import load_dataset, DataLoader3D, DataLoader2D, unpack_dataset
+from nnunet.training.loss_functions.dice_loss import DC_and_CE_loss
+from nnunet.training.network_training.network_trainer import NetworkTrainer
+from nnunet.utilities.nd_softmax import softmax_helper
+from nnunet.utilities.tensor_utilities import sum_tensor
+
+matplotlib.use("agg")
+
+
+class nnUNetTrainer(NetworkTrainer):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ """
+ :param deterministic:
+ :param fold: can be either [0 ... 5) for cross-validation, 'all' to train on all available training data or
+ None if you wish to load some checkpoint and do inference only
+ :param plans_file: the pkl file generated by preprocessing. This file will determine all design choices
+        :param subfolder_with_preprocessed_data: (not a constructor argument anymore; the folder is derived from
+        dataset_directory and the plans file in self.initialize) This is where the preprocessed data lies that will be
+        used for network training. We made this explicitly available so that differently preprocessed data can coexist
+        and the user can choose what to use. Can be None if you are doing inference only.
+        :param output_folder: where to store parameters, the progress plot and the validation results
+ :param dataset_directory: the parent directory in which the preprocessed Task data is stored. This is required
+ because the split information is stored in this directory. For running prediction only this input is not
+ required and may be set to None
+        :param batch_dice: if True, the dice loss treats the whole batch as one pseudo volume; if False, the dice loss
+        is computed for each sample individually and averaged over the batch
+ :param stage: The plans file may contain several stages (used for lowres / highres / pyramid). Stage must be
+ specified for training:
+ if stage 1 exists then stage 1 is the high resolution stage, otherwise it's 0
+ :param unpack_data: if False, npz preprocessed data will not be unpacked to npy. This consumes less space but
+ is considerably slower! Running unpack_data=False with 2d should never be done!
+
+ IMPORTANT: If you inherit from nnUNetTrainer and the init args change then you need to redefine self.init_args
+ in your init accordingly. Otherwise checkpoints won't load properly!
+ """
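+        # illustrative instantiation (paths are made up for the example):
+        #   trainer = nnUNetTrainer("/preproc/Task001/nnUNetPlansv2.1_plans_3D.pkl", fold=0,
+        #                           output_folder="/results/Task001/fold_0",
+        #                           dataset_directory="/preproc/Task001", stage=0)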
+ super(nnUNetTrainer, self).__init__(deterministic, fp16)
+ self.unpack_data = unpack_data
+ self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ # set through arguments from init
+ self.stage = stage
+ self.experiment_name = self.__class__.__name__
+ self.plans_file = plans_file
+ self.output_folder = output_folder
+ self.dataset_directory = dataset_directory
+ self.output_folder_base = self.output_folder
+ self.fold = fold
+
+ self.plans = None
+
+        # if we are running inference only then self.dataset_directory is set (due to checkpoint loading) but it
+        # is irrelevant
+ if self.dataset_directory is not None and isdir(self.dataset_directory):
+ self.gt_niftis_folder = join(self.dataset_directory, "gt_segmentations")
+ else:
+ self.gt_niftis_folder = None
+
+ self.folder_with_preprocessed_data = None
+
+ # set in self.initialize()
+
+ self.dl_tr = self.dl_val = None
+ self.num_input_channels = self.num_classes = self.net_pool_per_axis = self.patch_size = self.batch_size = \
+ self.threeD = self.base_num_features = self.intensity_properties = self.normalization_schemes = \
+ self.net_num_pool_op_kernel_sizes = self.net_conv_kernel_sizes = None # loaded automatically from plans_file
+ self.basic_generator_patch_size = self.data_aug_params = self.transpose_forward = self.transpose_backward = None
+
+ self.batch_dice = batch_dice
+ self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {})
+
+ self.online_eval_foreground_dc = []
+ self.online_eval_tp = []
+ self.online_eval_fp = []
+ self.online_eval_fn = []
+
+ self.classes = self.do_dummy_2D_aug = self.use_mask_for_norm = self.only_keep_largest_connected_component = \
+ self.min_region_size_per_class = self.min_size_per_class = None
+
+ self.inference_pad_border_mode = "constant"
+ self.inference_pad_kwargs = {'constant_values': 0}
+
+ self.update_fold(fold)
+ self.pad_all_sides = None
+
+ self.lr_scheduler_eps = 1e-3
+ self.lr_scheduler_patience = 30
+ self.initial_lr = 3e-4
+ self.weight_decay = 3e-5
+
+ self.oversample_foreground_percent = 0.33
+
+ self.conv_per_stage = None
+ self.regions_class_order = None
+
+ def update_fold(self, fold):
+ """
+ used to swap between folds for inference (ensemble of models from cross-validation)
+ DO NOT USE DURING TRAINING AS THIS WILL NOT UPDATE THE DATASET SPLIT AND THE DATA AUGMENTATION GENERATORS
+ :param fold:
+ :return:
+ """
+ if fold is not None:
+ if isinstance(fold, str):
+ assert fold == "all", "if self.fold is a string then it must be \'all\'"
+ if self.output_folder.endswith("%s" % str(self.fold)):
+ self.output_folder = self.output_folder_base
+ self.output_folder = join(self.output_folder, "%s" % str(fold))
+ else:
+ if self.output_folder.endswith("fold_%s" % str(self.fold)):
+ self.output_folder = self.output_folder_base
+ self.output_folder = join(self.output_folder, "fold_%s" % str(fold))
+ self.fold = fold
+
+ def setup_DA_params(self):
+ if self.threeD:
+ self.data_aug_params = default_3D_augmentation_params
+ if self.do_dummy_2D_aug:
+ self.data_aug_params["dummy_2D"] = True
+ self.print_to_log_file("Using dummy2d data augmentation")
+ self.data_aug_params["elastic_deform_alpha"] = \
+ default_2D_augmentation_params["elastic_deform_alpha"]
+ self.data_aug_params["elastic_deform_sigma"] = \
+ default_2D_augmentation_params["elastic_deform_sigma"]
+ self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
+ else:
+ self.do_dummy_2D_aug = False
+ if max(self.patch_size) / min(self.patch_size) > 1.5:
+ default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
+ self.data_aug_params = default_2D_augmentation_params
+ self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm
+
+ if self.do_dummy_2D_aug:
+ self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
+ self.data_aug_params['rotation_x'],
+ self.data_aug_params['rotation_y'],
+ self.data_aug_params['rotation_z'],
+ self.data_aug_params['scale_range'])
+ self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
+ else:
+ self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
+ self.data_aug_params['rotation_y'],
+ self.data_aug_params['rotation_z'],
+ self.data_aug_params['scale_range'])
+
+ self.data_aug_params['selected_seg_channels'] = [0]
+ self.data_aug_params['patch_size_for_spatialtransform'] = self.patch_size
+
+ def initialize(self, training=True, force_load_plans=False):
+ """
+ For prediction of test cases just set training=False, this will prevent loading of training data and
+ training batchgenerator initialization
+ :param training:
+ :return:
+ """
+
+ maybe_mkdir_p(self.output_folder)
+
+ if force_load_plans or (self.plans is None):
+ self.load_plans_file()
+
+ self.process_plans(self.plans)
+
+ self.setup_DA_params()
+
+ if training:
+ self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
+ "_stage%d" % self.stage)
+
+ self.dl_tr, self.dl_val = self.get_basic_generators()
+ if self.unpack_data:
+ self.print_to_log_file("unpacking dataset")
+ unpack_dataset(self.folder_with_preprocessed_data)
+ self.print_to_log_file("done")
+ else:
+ self.print_to_log_file(
+ "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
+ "will wait all winter for your model to finish!")
+ self.tr_gen, self.val_gen = get_default_augmentation(self.dl_tr, self.dl_val,
+ self.data_aug_params[
+ 'patch_size_for_spatialtransform'],
+ self.data_aug_params)
+ self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
+ also_print_to_console=False)
+ self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
+ also_print_to_console=False)
+ else:
+ pass
+ self.initialize_network()
+ self.initialize_optimizer_and_scheduler()
+ # assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
+ self.was_initialized = True
+
+ def initialize_network(self):
+ """
+ This is specific to the U-Net and must be adapted for other network architectures
+ :return:
+ """
+ # self.print_to_log_file(self.net_num_pool_op_kernel_sizes)
+ # self.print_to_log_file(self.net_conv_kernel_sizes)
+
+ net_numpool = len(self.net_num_pool_op_kernel_sizes)
+
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.InstanceNorm3d
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.InstanceNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
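+        # the trailing positional args configure: feature map doubling (factor 2) per downsampling, no
+        # deep supervision, identity as final nonlinearity (softmax is only attached for inference below),
+        # He weight initialization, and convolutional pooling/upsampling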
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes, net_numpool,
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
+ dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, False, False, lambda x: x, InitWeights_He(1e-2),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ self.network.inference_apply_nonlin = softmax_helper
+
+ if torch.cuda.is_available():
+ self.network.cuda()
+
+ def initialize_optimizer_and_scheduler(self):
+ assert self.network is not None, "self.initialize_network must be called first"
+ self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ amsgrad=True)
+ self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2,
+ patience=self.lr_scheduler_patience,
+ verbose=True, threshold=self.lr_scheduler_eps,
+ threshold_mode="abs")
+
+ def plot_network_architecture(self):
+ try:
+ from batchgenerators.utilities.file_and_folder_operations import join
+ import hiddenlayer as hl
+ if torch.cuda.is_available():
+ g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)).cuda(),
+ transforms=None)
+ else:
+ g = hl.build_graph(self.network, torch.rand((1, self.num_input_channels, *self.patch_size)),
+ transforms=None)
+ g.save(join(self.output_folder, "network_architecture.pdf"))
+ del g
+ except Exception as e:
+ self.print_to_log_file("Unable to plot network architecture:")
+ self.print_to_log_file(e)
+
+ self.print_to_log_file("\nprinting the network instead:\n")
+ self.print_to_log_file(self.network)
+ self.print_to_log_file("\n")
+ finally:
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ def save_debug_information(self):
+ # saving some debug information
+ dct = OrderedDict()
+ for k in self.__dir__():
+ if not k.startswith("__"):
+ if not callable(getattr(self, k)):
+ dct[k] = str(getattr(self, k))
+ del dct['plans']
+ del dct['intensity_properties']
+ del dct['dataset']
+ del dct['dataset_tr']
+ del dct['dataset_val']
+ save_json(dct, join(self.output_folder, "debug.json"))
+
+ shutil.copy(self.plans_file, join(self.output_folder_base, "plans.pkl"))
+
+ def run_training(self):
+ self.save_debug_information()
+ super(nnUNetTrainer, self).run_training()
+
+ def load_plans_file(self):
+ """
+ This is what actually configures the entire experiment. The plans file is generated by experiment planning
+ :return:
+ """
+ self.plans = load_pickle(self.plans_file)
+
+ def process_plans(self, plans):
+ if self.stage is None:
+ assert len(list(plans['plans_per_stage'].keys())) == 1, \
+ "If self.stage is None then there can be only one stage in the plans file. That seems to not be the " \
+ "case. Please specify which stage of the cascade must be trained"
+ self.stage = list(plans['plans_per_stage'].keys())[0]
+ self.plans = plans
+
+ stage_plans = self.plans['plans_per_stage'][self.stage]
+ self.batch_size = stage_plans['batch_size']
+ self.net_pool_per_axis = stage_plans['num_pool_per_axis']
+ self.patch_size = np.array(stage_plans['patch_size']).astype(int)
+ self.do_dummy_2D_aug = stage_plans['do_dummy_2D_data_aug']
+
+ if 'pool_op_kernel_sizes' not in stage_plans.keys():
+ assert 'num_pool_per_axis' in stage_plans.keys()
+ self.print_to_log_file("WARNING! old plans file with missing pool_op_kernel_sizes. Attempting to fix it...")
+ self.net_num_pool_op_kernel_sizes = []
+ for i in range(max(self.net_pool_per_axis)):
+ curr = []
+ for j in self.net_pool_per_axis:
+ if (max(self.net_pool_per_axis) - j) <= i:
+ curr.append(2)
+ else:
+ curr.append(1)
+ self.net_num_pool_op_kernel_sizes.append(curr)
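+            # e.g. net_pool_per_axis = [5, 5, 3] yields pool kernels
+            # [[2, 2, 1], [2, 2, 1], [2, 2, 2], [2, 2, 2], [2, 2, 2]]: axes with fewer pooling
+            # steps only start being pooled in the later stages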
+ else:
+ self.net_num_pool_op_kernel_sizes = stage_plans['pool_op_kernel_sizes']
+
+ if 'conv_kernel_sizes' not in stage_plans.keys():
+ self.print_to_log_file("WARNING! old plans file with missing conv_kernel_sizes. Attempting to fix it...")
+ self.net_conv_kernel_sizes = [[3] * len(self.net_pool_per_axis)] * (max(self.net_pool_per_axis) + 1)
+ else:
+ self.net_conv_kernel_sizes = stage_plans['conv_kernel_sizes']
+
+ self.pad_all_sides = None # self.patch_size
+ self.intensity_properties = plans['dataset_properties']['intensityproperties']
+ self.normalization_schemes = plans['normalization_schemes']
+ self.base_num_features = plans['base_num_features']
+ self.num_input_channels = plans['num_modalities']
+ self.num_classes = plans['num_classes'] + 1 # background is no longer in num_classes
+ self.classes = plans['all_classes']
+ self.use_mask_for_norm = plans['use_mask_for_norm']
+ self.only_keep_largest_connected_component = plans['keep_only_largest_region']
+ self.min_region_size_per_class = plans['min_region_size_per_class']
+ self.min_size_per_class = None # DONT USE THIS. plans['min_size_per_class']
+
+ if plans.get('transpose_forward') is None or plans.get('transpose_backward') is None:
+            print("WARNING! You seem to have data that was preprocessed with a previous version of nnU-Net. "
+                  "You should rerun preprocessing. We will proceed and assume that both transpose_forward "
+                  "and transpose_backward are [0, 1, 2]. If that is not correct then weird things will happen!")
+ plans['transpose_forward'] = [0, 1, 2]
+ plans['transpose_backward'] = [0, 1, 2]
+ self.transpose_forward = plans['transpose_forward']
+ self.transpose_backward = plans['transpose_backward']
+
+ if len(self.patch_size) == 2:
+ self.threeD = False
+ elif len(self.patch_size) == 3:
+ self.threeD = True
+ else:
+ raise RuntimeError("invalid patch size in plans file: %s" % str(self.patch_size))
+
+        if "conv_per_stage" in plans.keys():  # this has been added to the plans only recently
+ self.conv_per_stage = plans['conv_per_stage']
+ else:
+ self.conv_per_stage = 2
+
+ def load_dataset(self):
+ self.dataset = load_dataset(self.folder_with_preprocessed_data)
+
+ def get_basic_generators(self):
+ self.load_dataset()
+ self.do_split()
+
+ if self.threeD:
+ dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
+ False, oversample_foreground_percent=self.oversample_foreground_percent,
+ pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
+ dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, False,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
+ else:
+ dl_tr = DataLoader2D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
+ dl_val = DataLoader2D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ pad_mode="constant", pad_sides=self.pad_all_sides, memmap_mode='r')
+ return dl_tr, dl_val
+
+ def preprocess_patient(self, input_files):
+ """
+ Used to predict new unseen data. Not used for the preprocessing of the training/test data
+ :param input_files:
+ :return:
+ """
+ from nnunet.training.model_restore import recursive_find_python_class
+ preprocessor_name = self.plans.get('preprocessor_name')
+ if preprocessor_name is None:
+ if self.threeD:
+ preprocessor_name = "GenericPreprocessor"
+ else:
+ preprocessor_name = "PreprocessorFor2D"
+
+ print("using preprocessor", preprocessor_name)
+ preprocessor_class = recursive_find_python_class([join(nnunet.__path__[0], "preprocessing")],
+ preprocessor_name,
+ current_module="nnunet.preprocessing")
+ assert preprocessor_class is not None, "Could not find preprocessor %s in nnunet.preprocessing" % \
+ preprocessor_name
+ preprocessor = preprocessor_class(self.normalization_schemes, self.use_mask_for_norm,
+ self.transpose_forward, self.intensity_properties)
+
+ d, s, properties = preprocessor.preprocess_test_case(input_files,
+ self.plans['plans_per_stage'][self.stage][
+ 'current_spacing'])
+ return d, s, properties
+
+ def preprocess_predict_nifti(self, input_files: List[str], output_file: str = None,
+ softmax_ouput_file: str = None, mixed_precision: bool = True) -> None:
+ """
+ Use this to predict new data
+ :param input_files:
+ :param output_file:
+ :param softmax_ouput_file:
+ :param mixed_precision:
+ :return:
+ """
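+        # illustrative call (file names are made up):
+        #   self.preprocess_predict_nifti(["case0_0000.nii.gz"], "case0_pred.nii.gz")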
+ print("preprocessing...")
+ d, s, properties = self.preprocess_patient(input_files)
+ print("predicting...")
+ pred = self.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=self.data_aug_params["do_mirror"],
+ mirror_axes=self.data_aug_params['mirror_axes'],
+ use_sliding_window=True, step_size=0.5,
+ use_gaussian=True, pad_border_mode='constant',
+ pad_kwargs={'constant_values': 0},
+ verbose=True, all_in_gpu=False,
+ mixed_precision=mixed_precision)[1]
+ pred = pred.transpose([0] + [i + 1 for i in self.transpose_backward])
+
+ if 'segmentation_export_params' in self.plans.keys():
+ force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
+ interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
+ interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
+ else:
+ force_separate_z = None
+ interpolation_order = 1
+ interpolation_order_z = 0
+
+ print("resampling to original spacing and nifti export...")
+ save_segmentation_nifti_from_softmax(pred, output_file, properties, interpolation_order,
+ self.regions_class_order, None, None, softmax_ouput_file,
+ None, force_separate_z=force_separate_z,
+ interpolation_order_z=interpolation_order_z)
+ print("done")
+
+ def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
+ mirror_axes: Tuple[int] = None,
+ use_sliding_window: bool = True, step_size: float = 0.5,
+ use_gaussian: bool = True, pad_border_mode: str = 'constant',
+ pad_kwargs: dict = None, all_in_gpu: bool = False,
+ verbose: bool = True, mixed_precision: bool = True) -> Tuple[np.ndarray, np.ndarray]:
+ """
+        :param data: preprocessed image data of shape (c, x, y(, z))
+        :param do_mirroring: use test time augmentation by mirroring
+        :param mirror_axes: axes over which to mirror; defaults to the axes that were mirrored during training
+        :param use_sliding_window: predict with a tiled sliding window instead of a single forward pass
+        :param step_size: distance between sliding window tiles as a fraction of the patch size
+        :param use_gaussian: weigh the sliding window predictions with a Gaussian so that tile centers count more
+        :param pad_border_mode: np.pad mode used to pad the data up to the patch size
+        :param pad_kwargs: kwargs for np.pad
+        :param all_in_gpu: keep intermediate results on the GPU (faster but needs more GPU memory)
+        :param verbose: print progress information
+        :param mixed_precision: run inference with torch.cuda.amp autocast
+        :return: (segmentation, softmax) as numpy arrays
+ """
+ if pad_border_mode == 'constant' and pad_kwargs is None:
+ pad_kwargs = {'constant_values': 0}
+
+ if do_mirroring and mirror_axes is None:
+ mirror_axes = self.data_aug_params['mirror_axes']
+
+ if do_mirroring:
+ assert self.data_aug_params["do_mirror"], "Cannot do mirroring as test time augmentation when training " \
+ "was done without mirroring"
+
+        assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
+
+ current_mode = self.network.training
+ self.network.eval()
+ ret = self.network.predict_3D(data, do_mirroring=do_mirroring, mirror_axes=mirror_axes,
+ use_sliding_window=use_sliding_window, step_size=step_size,
+ patch_size=self.patch_size, regions_class_order=self.regions_class_order,
+ use_gaussian=use_gaussian, pad_border_mode=pad_border_mode,
+ pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, verbose=verbose,
+ mixed_precision=mixed_precision)
+ self.network.train(current_mode)
+ return ret
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
+ save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
+ """
+ if debug=True then the temporary files generated for postprocessing determination will be kept
+ """
+
+ current_mode = self.network.training
+ self.network.eval()
+
+ assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"
+ if self.dataset_val is None:
+ self.load_dataset()
+ self.do_split()
+
+ if segmentation_export_kwargs is None:
+ if 'segmentation_export_params' in self.plans.keys():
+ force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
+ interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
+ interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
+ else:
+ force_separate_z = None
+ interpolation_order = 1
+ interpolation_order_z = 0
+ else:
+ force_separate_z = segmentation_export_kwargs['force_separate_z']
+ interpolation_order = segmentation_export_kwargs['interpolation_order']
+ interpolation_order_z = segmentation_export_kwargs['interpolation_order_z']
+
+ # predictions as they come from the network go here
+ output_folder = join(self.output_folder, validation_folder_name)
+ maybe_mkdir_p(output_folder)
+ # this is for debug purposes
+ my_input_args = {'do_mirroring': do_mirroring,
+ 'use_sliding_window': use_sliding_window,
+ 'step_size': step_size,
+ 'save_softmax': save_softmax,
+ 'use_gaussian': use_gaussian,
+ 'overwrite': overwrite,
+ 'validation_folder_name': validation_folder_name,
+ 'debug': debug,
+ 'all_in_gpu': all_in_gpu,
+ 'segmentation_export_kwargs': segmentation_export_kwargs,
+ }
+ save_json(my_input_args, join(output_folder, "validation_args.json"))
+
+ if do_mirroring:
+ if not self.data_aug_params['do_mirror']:
+ raise RuntimeError("We did not train with mirroring so you cannot do inference with mirroring enabled")
+ mirror_axes = self.data_aug_params['mirror_axes']
+ else:
+ mirror_axes = ()
+
+ pred_gt_tuples = []
+
+ export_pool = Pool(default_num_threads)
+ results = []
+
+ for k in self.dataset_val.keys():
+ properties = load_pickle(self.dataset[k]['properties_file'])
+ fname = properties['list_of_data_files'][0].split("/")[-1][:-12]
+ if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \
+ (save_softmax and not isfile(join(output_folder, fname + ".npz"))):
+ data = np.load(self.dataset[k]['data_file'])['data']
+
+ print(k, data.shape)
+ data[-1][data[-1] == -1] = 0
+
+ softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data[:-1],
+ do_mirroring=do_mirroring,
+ mirror_axes=mirror_axes,
+ use_sliding_window=use_sliding_window,
+ step_size=step_size,
+ use_gaussian=use_gaussian,
+ all_in_gpu=all_in_gpu,
+ mixed_precision=self.fp16)[1]
+
+ softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in self.transpose_backward])
+
+ if save_softmax:
+ softmax_fname = join(output_folder, fname + ".npz")
+ else:
+ softmax_fname = None
+
+ """There is a problem with python process communication that prevents us from communicating objects
+ larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
+ communicated by the multiprocessing.Pipe object then the placeholder (I think) does not allow for long
+ enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
+ patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
+ then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
+ filename or np.ndarray and will handle this automatically"""
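+                # softmax_pred is float32 (4 bytes per element), so 2e9 / 4 elements corresponds to the
+                # ~2 GB pickle limit; the factor 0.85 leaves some headroom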
+                if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85):  # *0.85 just to be safe
+ np.save(join(output_folder, fname + ".npy"), softmax_pred)
+ softmax_pred = join(output_folder, fname + ".npy")
+
+ results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax,
+ ((softmax_pred, join(output_folder, fname + ".nii.gz"),
+ properties, interpolation_order, self.regions_class_order,
+ None, None,
+ softmax_fname, None, force_separate_z,
+ interpolation_order_z),
+ )
+ )
+ )
+
+ pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"),
+ join(self.gt_niftis_folder, fname + ".nii.gz")])
+
+ _ = [i.get() for i in results]
+ self.print_to_log_file("finished prediction")
+
+ # evaluate raw predictions
+ self.print_to_log_file("evaluation of raw predictions")
+ task = self.dataset_directory.split("/")[-1]
+ job_name = self.experiment_name
+ _ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)),
+ json_output_file=join(output_folder, "summary.json"),
+ json_name=job_name + " val tiled %s" % (str(use_sliding_window)),
+ json_author="Fabian",
+ json_task=task, num_threads=default_num_threads)
+
+ if run_postprocessing_on_folds:
+ # in the old nnunet we would stop here. Now we add a postprocessing. This postprocessing can remove everything
+ # except the largest connected component for each class. To see if this improves results, we do this for all
+ # classes and then rerun the evaluation. Those classes for which this resulted in an improved dice score will
+ # have this applied during inference as well
+ self.print_to_log_file("determining postprocessing")
+ determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name,
+ final_subf_name=validation_folder_name + "_postprocessed", debug=debug)
+            # after this the final predictions for the validation set can be found in validation_folder_name_base + "_postprocessed"
+            # They are always in that folder, even if no postprocessing was applied!
+
+        # determining postprocessing on a per-fold basis may be OK for this fold but what if another fold finds another
+        # postprocessing to be better? In this case we need to consolidate. At the time the consolidation is going to be
+ # done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a separate folder to
+ # be used later
+ gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
+ maybe_mkdir_p(gt_nifti_folder)
+ for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
+ success = False
+ attempts = 0
+ e = None
+ while not success and attempts < 10:
+ try:
+ shutil.copy(f, gt_nifti_folder)
+ success = True
+                except OSError as exc:
+                    # `except ... as exc` unbinds the name when the block exits (python 3), so keep
+                    # a reference around for the re-raise below
+                    e = exc
+                    attempts += 1
+                    sleep(1)
+ if not success:
+ print("Could not copy gt nifti file %s into folder %s" % (f, gt_nifti_folder))
+ if e is not None:
+ raise e
+
+ self.network.train(current_mode)
+
+ def run_online_evaluation(self, output, target):
+ with torch.no_grad():
+ num_classes = output.shape[1]
+ output_softmax = softmax_helper(output)
+ output_seg = output_softmax.argmax(1)
+ target = target[:, 0]
+ axes = tuple(range(1, len(target.shape)))
+ tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
+ fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
+ fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
+ for c in range(1, num_classes):
+ tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes)
+ fp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes)
+ fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes)
+
+ tp_hard = tp_hard.sum(0, keepdim=False).detach().cpu().numpy()
+ fp_hard = fp_hard.sum(0, keepdim=False).detach().cpu().numpy()
+ fn_hard = fn_hard.sum(0, keepdim=False).detach().cpu().numpy()
+
+ self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
+ self.online_eval_tp.append(list(tp_hard))
+ self.online_eval_fp.append(list(fp_hard))
+ self.online_eval_fn.append(list(fn_hard))
+
+ def finish_online_evaluation(self):
+ self.online_eval_tp = np.sum(self.online_eval_tp, 0)
+ self.online_eval_fp = np.sum(self.online_eval_fp, 0)
+ self.online_eval_fn = np.sum(self.online_eval_fn, 0)
+
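+        # "global" Dice: tp/fp/fn are aggregated over all validation batches of the epoch and Dice is
+        # computed once per class, which is more robust than averaging per-batch Dice values when a
+        # class is absent from some batches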
+ global_dc_per_class = [i for i in [2 * i / (2 * i + j + k) for i, j, k in
+ zip(self.online_eval_tp, self.online_eval_fp, self.online_eval_fn)]
+ if not np.isnan(i)]
+ self.all_val_eval_metrics.append(np.mean(global_dc_per_class))
+
+ self.print_to_log_file("Average global foreground Dice:", [np.round(i, 4) for i in global_dc_per_class])
+ self.print_to_log_file("(interpret this as an estimate for the Dice of the different classes. This is not "
+ "exact.)")
+
+ self.online_eval_foreground_dc = []
+ self.online_eval_tp = []
+ self.online_eval_fp = []
+ self.online_eval_fn = []
+
+ def save_checkpoint(self, fname, save_optimizer=True):
+ super(nnUNetTrainer, self).save_checkpoint(fname, save_optimizer)
+ info = OrderedDict()
+ info['init'] = self.init_args
+ info['name'] = self.__class__.__name__
+ info['class'] = str(self.__class__)
+ info['plans'] = self.plans
+
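+        # the .pkl written next to the checkpoint stores everything needed to re-instantiate this
+        # trainer (class, init args, plans) when the checkpoint is restored later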
+ write_pickle(info, fname + ".pkl")
diff --git a/nnunet/training/network_training/nnUNetTrainerCascadeFullRes.py b/nnunet/training/network_training/nnUNetTrainerCascadeFullRes.py
new file mode 100644
index 0000000000000000000000000000000000000000..e362743a87f323e7a1b049f64620ab84371f2016
--- /dev/null
+++ b/nnunet/training/network_training/nnUNetTrainerCascadeFullRes.py
@@ -0,0 +1,290 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from multiprocessing.pool import Pool
+from time import sleep
+
+import matplotlib
+from nnunet.postprocessing.connected_components import determine_postprocessing
+from nnunet.training.data_augmentation.default_data_augmentation import get_default_augmentation
+from nnunet.training.dataloading.dataset_loading import DataLoader3D, unpack_dataset
+from nnunet.evaluation.evaluator import aggregate_scores
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.paths import network_training_output_dir
+from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax
+from batchgenerators.utilities.file_and_folder_operations import *
+import numpy as np
+from nnunet.utilities.one_hot_encoding import to_one_hot
+import shutil
+
+matplotlib.use("agg")
+
+
+class nnUNetTrainerCascadeFullRes(nnUNetTrainer):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, previous_trainer="nnUNetTrainer", fp16=False):
+ super(nnUNetTrainerCascadeFullRes, self).__init__(plans_file, fold, output_folder, dataset_directory,
+ batch_dice, stage, unpack_data, deterministic, fp16)
+ self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, previous_trainer, fp16)
+
+ if self.output_folder is not None:
+ task = self.output_folder.split("/")[-3]
+ plans_identifier = self.output_folder.split("/")[-2].split("__")[-1]
+
+ folder_with_segs_prev_stage = join(network_training_output_dir, "3d_lowres",
+ task, previous_trainer + "__" + plans_identifier, "pred_next_stage")
+ if not isdir(folder_with_segs_prev_stage):
+ raise RuntimeError(
+ "Cannot run final stage of cascade. Run corresponding 3d_lowres first and predict the "
+ "segmentations for the next stage")
+ self.folder_with_segs_from_prev_stage = folder_with_segs_prev_stage
+ # Do not put segs_prev_stage into self.output_folder as we need to unpack them for performance and we
+            # don't want to do that in self.output_folder because that one may be located on a network drive.
+ else:
+ self.folder_with_segs_from_prev_stage = None
+
+ def do_split(self):
+ super(nnUNetTrainerCascadeFullRes, self).do_split()
+ for k in self.dataset:
+ self.dataset[k]['seg_from_prev_stage_file'] = join(self.folder_with_segs_from_prev_stage,
+ k + "_segFromPrevStage.npz")
+ assert isfile(self.dataset[k]['seg_from_prev_stage_file']), \
+ "seg from prev stage missing: %s" % (self.dataset[k]['seg_from_prev_stage_file'])
+ for k in self.dataset_val:
+ self.dataset_val[k]['seg_from_prev_stage_file'] = join(self.folder_with_segs_from_prev_stage,
+ k + "_segFromPrevStage.npz")
+ for k in self.dataset_tr:
+ self.dataset_tr[k]['seg_from_prev_stage_file'] = join(self.folder_with_segs_from_prev_stage,
+ k + "_segFromPrevStage.npz")
+
+ def get_basic_generators(self):
+ self.load_dataset()
+ self.do_split()
+ if self.threeD:
+ dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
+ True, oversample_foreground_percent=self.oversample_foreground_percent)
+ dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, True,
+ oversample_foreground_percent=self.oversample_foreground_percent)
+ else:
+ raise NotImplementedError
+ return dl_tr, dl_val
+
+ def process_plans(self, plans):
+ super(nnUNetTrainerCascadeFullRes, self).process_plans(plans)
+ self.num_input_channels += (self.num_classes - 1) # for seg from prev stage
+
+ def setup_DA_params(self):
+ super().setup_DA_params()
+ self.data_aug_params['move_last_seg_chanel_to_data'] = True
+ self.data_aug_params['cascade_do_cascade_augmentations'] = True
+
+ self.data_aug_params['cascade_random_binary_transform_p'] = 0.4
+ self.data_aug_params['cascade_random_binary_transform_p_per_label'] = 1
+ self.data_aug_params['cascade_random_binary_transform_size'] = (1, 8)
+
+ self.data_aug_params['cascade_remove_conn_comp_p'] = 0.2
+ self.data_aug_params['cascade_remove_conn_comp_max_size_percent_threshold'] = 0.15
+ self.data_aug_params['cascade_remove_conn_comp_fill_with_other_class_p'] = 0.0
+
+ # we have 2 channels now because the segmentation from the previous stage is stored in 'seg' as well until it
+ # is moved to 'data' at the end
+ self.data_aug_params['selected_seg_channels'] = [0, 1]
+ # needed for converting the segmentation from the previous stage to one hot
+ self.data_aug_params['all_segmentation_labels'] = list(range(1, self.num_classes))
+
+ def initialize(self, training=True, force_load_plans=False):
+ """
+ For prediction of test cases just set training=False, this will prevent loading of training data and
+ training batchgenerator initialization
+ :param training:
+ :return:
+ """
+ if force_load_plans or (self.plans is None):
+ self.load_plans_file()
+
+ self.process_plans(self.plans)
+
+ self.setup_DA_params()
+
+ self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
+ "_stage%d" % self.stage)
+ if training:
+
+ if self.folder_with_preprocessed_data is not None:
+ self.dl_tr, self.dl_val = self.get_basic_generators()
+
+ if self.unpack_data:
+ print("unpacking dataset")
+ unpack_dataset(self.folder_with_preprocessed_data)
+ print("done")
+ else:
+ print(
+ "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
+ "will wait all winter for your model to finish!")
+
+ self.tr_gen, self.val_gen = get_default_augmentation(self.dl_tr, self.dl_val,
+ self.data_aug_params[
+ 'patch_size_for_spatialtransform'],
+ self.data_aug_params)
+ self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())))
+ self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())))
+ else:
+ pass
+ self.initialize_network()
+ assert isinstance(self.network, SegmentationNetwork)
+ self.was_initialized = True
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
+ step_size: float = 0.5,
+ save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
+
+ current_mode = self.network.training
+ self.network.eval()
+
+ assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"
+ if self.dataset_val is None:
+ self.load_dataset()
+ self.do_split()
+
+ if segmentation_export_kwargs is None:
+ if 'segmentation_export_params' in self.plans.keys():
+ force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
+ interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
+ interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
+ else:
+ force_separate_z = None
+ interpolation_order = 1
+ interpolation_order_z = 0
+ else:
+ force_separate_z = segmentation_export_kwargs['force_separate_z']
+ interpolation_order = segmentation_export_kwargs['interpolation_order']
+ interpolation_order_z = segmentation_export_kwargs['interpolation_order_z']
+
+ output_folder = join(self.output_folder, validation_folder_name)
+ maybe_mkdir_p(output_folder)
+
+ if do_mirroring:
+ mirror_axes = self.data_aug_params['mirror_axes']
+ else:
+ mirror_axes = ()
+
+ pred_gt_tuples = []
+
+ export_pool = Pool(2)
+ results = []
+
+ transpose_backward = self.plans.get('transpose_backward')
+
+ for k in self.dataset_val.keys():
+ properties = load_pickle(self.dataset[k]['properties_file'])
+ data = np.load(self.dataset[k]['data_file'])['data']
+
+ # concat segmentation of previous step
+ seg_from_prev_stage = np.load(join(self.folder_with_segs_from_prev_stage,
+ k + "_segFromPrevStage.npz"))['data'][None]
+
+ print(data.shape)
+ data[-1][data[-1] == -1] = 0
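+            # the cascade's network input = image modalities + one-hot encoding of the previous stage's
+            # segmentation (the background channel is omitted, hence range(1, self.num_classes))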
+ data_for_net = np.concatenate((data[:-1], to_one_hot(seg_from_prev_stage[0], range(1, self.num_classes))))
+
+ softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data_for_net,
+ do_mirroring=do_mirroring,
+ mirror_axes=mirror_axes,
+ use_sliding_window=use_sliding_window,
+ step_size=step_size,
+ use_gaussian=use_gaussian,
+ all_in_gpu=all_in_gpu,
+ mixed_precision=self.fp16)[1]
+
+            if transpose_backward is not None:
+                softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in transpose_backward])
+
+ fname = properties['list_of_data_files'][0].split("/")[-1][:-12]
+
+ if save_softmax:
+ softmax_fname = join(output_folder, fname + ".npz")
+ else:
+ softmax_fname = None
+
+ """There is a problem with python process communication that prevents us from communicating objects
+ larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
+ communicated by the multiprocessing.Pipe object then the placeholder (I think) does not allow for long
+ enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
+ patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
+ then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
+ filename or np.ndarray and will handle this automatically"""
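+            # For reference: with float32 (4 bytes per element) the ~2e9 byte pickle limit corresponds to
+            # 2e9 / 4 = 5e8 elements; the 0.85 factor lowers the cutoff to ~4.25e8 elements (~1.7 GB)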
+            if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85):  # *0.85 just to be safe
+                np.save(join(output_folder, fname + ".npy"), softmax_pred)
+                softmax_pred = join(output_folder, fname + ".npy")
+
+ results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax,
+ ((softmax_pred, join(output_folder, fname + ".nii.gz"),
+ properties, interpolation_order, self.regions_class_order,
+ None, None,
+ softmax_fname, None, force_separate_z,
+ interpolation_order_z),
+ )
+ )
+ )
+
+ pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"),
+ join(self.gt_niftis_folder, fname + ".nii.gz")])
+
+ _ = [i.get() for i in results]
+
+ task = self.dataset_directory.split("/")[-1]
+ job_name = self.experiment_name
+ _ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)),
+ json_output_file=join(output_folder, "summary.json"), json_name=job_name,
+ json_author="Fabian", json_description="",
+ json_task=task)
+
+ if run_postprocessing_on_folds:
+ # in the old nnunet we would stop here. Now we add a postprocessing. This postprocessing can remove everything
+ # except the largest connected component for each class. To see if this improves results, we do this for all
+ # classes and then rerun the evaluation. Those classes for which this resulted in an improved dice score will
+ # have this applied during inference as well
+ self.print_to_log_file("determining postprocessing")
+ determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name,
+ final_subf_name=validation_folder_name + "_postprocessed", debug=debug)
+            # after this the final predictions for the validation set can be found in validation_folder_name_base + "_postprocessed"
+            # They are always in that folder, even if no postprocessing was applied!
+
+        # determining postprocessing on a per-fold basis may be OK for this fold but what if another fold finds another
+        # postprocessing to be better? In this case we need to consolidate. At the time the consolidation is going to be
+ # done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a separate folder to
+ # be used later
+ gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
+ maybe_mkdir_p(gt_nifti_folder)
+ for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
+ success = False
+ attempts = 0
+ while not success and attempts < 10:
+ try:
+ shutil.copy(f, gt_nifti_folder)
+ success = True
+ except OSError:
+ attempts += 1
+ sleep(1)
+
+ self.network.train(current_mode)
+ export_pool.close()
+ export_pool.join()
\ No newline at end of file
diff --git a/nnunet/training/network_training/nnUNetTrainerV2.py b/nnunet/training/network_training/nnUNetTrainerV2.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5e77e26501d46c1d0276b375d91e2d22e445b2b
--- /dev/null
+++ b/nnunet/training/network_training/nnUNetTrainerV2.py
@@ -0,0 +1,442 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from collections import OrderedDict
+from typing import Tuple
+
+import numpy as np
+import torch
+from nnunet.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation
+from nnunet.training.loss_functions.deep_supervision import MultipleOutputLoss2
+from nnunet.utilities.to_torch import maybe_to_torch, to_cuda
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.training.data_augmentation.default_data_augmentation import default_2D_augmentation_params, \
+ get_patch_size, default_3D_augmentation_params
+from nnunet.training.dataloading.dataset_loading import unpack_dataset
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+from nnunet.utilities.nd_softmax import softmax_helper
+from sklearn.model_selection import KFold
+from torch import nn
+from torch.cuda.amp import autocast
+from nnunet.training.learning_rate.poly_lr import poly_lr
+from batchgenerators.utilities.file_and_folder_operations import *
+
+
+class nnUNetTrainerV2(nnUNetTrainer):
+ """
+ Info for Fabian: same as internal nnUNetTrainerV2_2
+ """
+
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.max_num_epochs = 1000
+ self.initial_lr = 1e-2
+ self.deep_supervision_scales = None
+ self.ds_loss_weights = None
+
+ self.pin_memory = True
+
+ def initialize(self, training=True, force_load_plans=False):
+ """
+ - replaced get_default_augmentation with get_moreDA_augmentation
+        - ensure this code only runs once
+ - loss function wrapper for deep supervision
+
+ :param training:
+ :param force_load_plans:
+ :return:
+ """
+ if not self.was_initialized:
+ maybe_mkdir_p(self.output_folder)
+
+ if force_load_plans or (self.plans is None):
+ self.load_plans_file()
+
+ self.process_plans(self.plans)
+
+ self.setup_DA_params()
+
+ ################# Here we wrap the loss for deep supervision ############
+ # we need to know the number of outputs of the network
+ net_numpool = len(self.net_num_pool_op_kernel_sizes)
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
+
+            # we don't use the lowest resolution output. Normalize weights so that they sum to 1
+ mask = np.array([True] + [True if i < net_numpool - 1 else False for i in range(1, net_numpool)])
+ weights[~mask] = 0
+ weights = weights / weights.sum()
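+            # worked example: for net_numpool = 5 the raw weights are [1, 1/2, 1/4, 1/8, 1/16]; the lowest
+            # resolution entry is zeroed and normalization gives approximately [0.53, 0.27, 0.13, 0.07, 0]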
+ self.ds_loss_weights = weights
+ # now wrap the loss
+ self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
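+            # MultipleOutputLoss2 applies the wrapped loss to each deep supervision output/target pair and
+            # sums the results, weighted by ds_loss_weights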
+ ################# END ###################
+
+ self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
+ "_stage%d" % self.stage)
+ if training:
+ self.dl_tr, self.dl_val = self.get_basic_generators()
+ if self.unpack_data:
+ print("unpacking dataset")
+ unpack_dataset(self.folder_with_preprocessed_data)
+ print("done")
+ else:
+ print(
+ "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
+ "will wait all winter for your model to finish!")
+
+ self.tr_gen, self.val_gen = get_moreDA_augmentation(
+ self.dl_tr, self.dl_val,
+ self.data_aug_params[
+ 'patch_size_for_spatialtransform'],
+ self.data_aug_params,
+ deep_supervision_scales=self.deep_supervision_scales,
+ pin_memory=self.pin_memory,
+ use_nondetMultiThreadedAugmenter=False
+ )
+ self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
+ also_print_to_console=False)
+ self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
+ also_print_to_console=False)
+ else:
+ pass
+
+ self.initialize_network()
+ self.initialize_optimizer_and_scheduler()
+
+ assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
+ else:
+ self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
+ self.was_initialized = True
+
+ def initialize_network(self):
+ """
+ - momentum 0.99
+ - SGD instead of Adam
+ - self.lr_scheduler = None because we do poly_lr
+ - deep supervision = True
+        - I am sure I forgot something here
+
+ Known issue: forgot to set neg_slope=0 in InitWeights_He; should not make a difference though
+ :return:
+ """
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.InstanceNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.InstanceNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
+ dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
+
+ def initialize_optimizer_and_scheduler(self):
+ assert self.network is not None, "self.initialize_network must be called first"
+ self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ momentum=0.99, nesterov=True)
+ self.lr_scheduler = None
+
+ def run_online_evaluation(self, output, target):
+ """
+ due to deep supervision the return value and the reference are now lists of tensors. We only need the full
+ resolution output because this is what we are interested in in the end. The others are ignored
+ :param output:
+ :param target:
+ :return:
+ """
+ target = target[0]
+ output = output[0]
+ return super().run_online_evaluation(output, target)
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
+ step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
+ """
+ We need to wrap this because we need to enforce self.network.do_ds = False for prediction
+ """
+ ds = self.network.do_ds
+ self.network.do_ds = False
+ ret = super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window, step_size=step_size,
+ save_softmax=save_softmax, use_gaussian=use_gaussian,
+ overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug,
+ all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs,
+ run_postprocessing_on_folds=run_postprocessing_on_folds)
+
+ self.network.do_ds = ds
+ return ret
+
+ def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
+ mirror_axes: Tuple[int] = None,
+ use_sliding_window: bool = True, step_size: float = 0.5,
+ use_gaussian: bool = True, pad_border_mode: str = 'constant',
+ pad_kwargs: dict = None, all_in_gpu: bool = False,
+ verbose: bool = True, mixed_precision=True) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ We need to wrap this because we need to enforce self.network.do_ds = False for prediction
+ """
+ ds = self.network.do_ds
+ self.network.do_ds = False
+ ret = super().predict_preprocessed_data_return_seg_and_softmax(data,
+ do_mirroring=do_mirroring,
+ mirror_axes=mirror_axes,
+ use_sliding_window=use_sliding_window,
+ step_size=step_size, use_gaussian=use_gaussian,
+ pad_border_mode=pad_border_mode,
+ pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu,
+ verbose=verbose,
+ mixed_precision=mixed_precision)
+ self.network.do_ds = ds
+ return ret
+
+ def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
+ """
+ gradient clipping improves training stability
+
+ :param data_generator:
+ :param do_backprop:
+ :param run_online_evaluation:
+ :return:
+ """
+ data_dict = next(data_generator)
+ data = data_dict['data']
+ target = data_dict['target']
+
+ data = maybe_to_torch(data)
+ target = maybe_to_torch(target)
+
+ if torch.cuda.is_available():
+ data = to_cuda(data)
+ target = to_cuda(target)
+
+ self.optimizer.zero_grad()
+
+ if self.fp16:
+ with autocast():
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ if do_backprop:
+ self.amp_grad_scaler.scale(l).backward()
+ self.amp_grad_scaler.unscale_(self.optimizer)
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.amp_grad_scaler.step(self.optimizer)
+ self.amp_grad_scaler.update()
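+                    # the block above follows the standard AMP pattern: scale the loss, backprop, unscale
+                    # so gradient clipping sees true gradient magnitudes, then step and update the scaler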
+ else:
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ if do_backprop:
+ l.backward()
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.optimizer.step()
+
+ if run_online_evaluation:
+ self.run_online_evaluation(output, target)
+
+ del target
+
+ return l.detach().cpu().numpy()
+
+ def do_split(self):
+ """
+ The default split is a 5 fold CV on all available training cases. nnU-Net will create a split (it is seeded,
+ so always the same) and save it as splits_final.pkl file in the preprocessed data directory.
+ Sometimes you may want to create your own split for various reasons. For this you will need to create your own
+ splits_final.pkl file. If this file is present, nnU-Net is going to use it and whatever splits are defined in
+ it. You can create as many splits in this file as you want. Note that if you define only 4 splits (fold 0-3)
+ and then set fold=4 when training (that would be the fifth split), nnU-Net will print a warning and proceed to
+ use a random 80:20 data split.
+ :return:
+ """
+ if self.fold == "all":
+ # if fold==all then we use all images for training and validation
+ tr_keys = val_keys = list(self.dataset.keys())
+ else:
+ splits_file = join(self.dataset_directory, "splits_final.pkl")
+
+ # if the split file does not exist we need to create it
+ if not isfile(splits_file):
+ self.print_to_log_file("Creating new 5-fold cross-validation split...")
+ splits = []
+ all_keys_sorted = np.sort(list(self.dataset.keys()))
+ kfold = KFold(n_splits=5, shuffle=True, random_state=12345)
+ for i, (train_idx, test_idx) in enumerate(kfold.split(all_keys_sorted)):
+ train_keys = np.array(all_keys_sorted)[train_idx]
+ test_keys = np.array(all_keys_sorted)[test_idx]
+ splits.append(OrderedDict())
+ splits[-1]['train'] = train_keys
+ splits[-1]['val'] = test_keys
+ save_pickle(splits, splits_file)
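+                # the resulting structure is a list of five dicts, e.g.
+                # splits[0] = OrderedDict([('train', array of ~80% of the case ids), ('val', the rest)])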
+
+ else:
+ self.print_to_log_file("Using splits from existing split file:", splits_file)
+ splits = load_pickle(splits_file)
+ self.print_to_log_file("The split file contains %d splits." % len(splits))
+
+ self.print_to_log_file("Desired fold for training: %d" % self.fold)
+ if self.fold < len(splits):
+ tr_keys = splits[self.fold]['train']
+ val_keys = splits[self.fold]['val']
+ self.print_to_log_file("This split has %d training and %d validation cases."
+ % (len(tr_keys), len(val_keys)))
+ else:
+ self.print_to_log_file("INFO: You requested fold %d for training but splits "
+ "contain only %d folds. I am now creating a "
+ "random (but seeded) 80:20 split!" % (self.fold, len(splits)))
+ # if we request a fold that is not in the split file, create a random 80:20 split
+ rnd = np.random.RandomState(seed=12345 + self.fold)
+ keys = np.sort(list(self.dataset.keys()))
+ idx_tr = rnd.choice(len(keys), int(len(keys) * 0.8), replace=False)
+ idx_val = [i for i in range(len(keys)) if i not in idx_tr]
+ tr_keys = [keys[i] for i in idx_tr]
+ val_keys = [keys[i] for i in idx_val]
+ self.print_to_log_file("This random 80:20 split has %d training and %d validation cases."
+ % (len(tr_keys), len(val_keys)))
+
+ tr_keys.sort()
+ val_keys.sort()
+ self.dataset_tr = OrderedDict()
+ for i in tr_keys:
+ self.dataset_tr[i] = self.dataset[i]
+ self.dataset_val = OrderedDict()
+ for i in val_keys:
+ self.dataset_val[i] = self.dataset[i]
+
+ def setup_DA_params(self):
+ """
+        - we increase the rotation angle from [-15, 15] to [-30, 30]
+ - scale range is now (0.7, 1.4), was (0.85, 1.25)
+ - we don't do elastic deformation anymore
+
+ :return:
+ """
+
+ self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(
+ np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1]
+
+ if self.threeD:
+ self.data_aug_params = default_3D_augmentation_params
+ self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
+ self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
+ self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
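+            # x / 360 * 2 * pi converts degrees to radians, so the ranges above are +/- 30 degrees
+            # (~0.5236 rad) per axis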
+ if self.do_dummy_2D_aug:
+ self.data_aug_params["dummy_2D"] = True
+ self.print_to_log_file("Using dummy2d data augmentation")
+ self.data_aug_params["elastic_deform_alpha"] = \
+ default_2D_augmentation_params["elastic_deform_alpha"]
+ self.data_aug_params["elastic_deform_sigma"] = \
+ default_2D_augmentation_params["elastic_deform_sigma"]
+ self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
+ else:
+ self.do_dummy_2D_aug = False
+ if max(self.patch_size) / min(self.patch_size) > 1.5:
+ default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
+ self.data_aug_params = default_2D_augmentation_params
+ self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm
+
+ if self.do_dummy_2D_aug:
+ self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
+ self.data_aug_params['rotation_x'],
+ self.data_aug_params['rotation_y'],
+ self.data_aug_params['rotation_z'],
+ self.data_aug_params['scale_range'])
+ self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
+ else:
+ self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
+ self.data_aug_params['rotation_y'],
+ self.data_aug_params['rotation_z'],
+ self.data_aug_params['scale_range'])
+
+ self.data_aug_params["scale_range"] = (0.7, 1.4)
+ self.data_aug_params["do_elastic"] = False
+ self.data_aug_params['selected_seg_channels'] = [0]
+ self.data_aug_params['patch_size_for_spatialtransform'] = self.patch_size
+
+ self.data_aug_params["num_cached_per_thread"] = 2
+
+ def maybe_update_lr(self, epoch=None):
+ """
+        if epoch is not None we use that value, else we use epoch = self.epoch + 1
+
+        (maybe_update_lr is called in on_epoch_end which is called before epoch is incremented.
+        Therefore we need to do +1 here)
+
+ :param epoch:
+ :return:
+ """
+ if epoch is None:
+ ep = self.epoch + 1
+ else:
+ ep = epoch
+ self.optimizer.param_groups[0]['lr'] = poly_lr(ep, self.max_num_epochs, self.initial_lr, 0.9)
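+        # assuming the standard nnU-Net definition of poly_lr, this computes
+        # initial_lr * (1 - ep / max_num_epochs) ** 0.9, e.g. at ep = 500 of 1000 the lr is ~0.536 * 1e-2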
+ self.print_to_log_file("lr:", np.round(self.optimizer.param_groups[0]['lr'], decimals=6))
+
+ def on_epoch_end(self):
+ """
+        overwrite patience-based early stopping. Always run to 1000 epochs
+ :return:
+ """
+ super().on_epoch_end()
+ continue_training = self.epoch < self.max_num_epochs
+
+ # it can rarely happen that the momentum of nnUNetTrainerV2 is too high for some dataset. If at epoch 100 the
+ # estimated validation Dice is still 0 then we reduce the momentum from 0.99 to 0.95
+ if self.epoch == 100:
+ if self.all_val_eval_metrics[-1] == 0:
+ self.optimizer.param_groups[0]["momentum"] = 0.95
+ self.network.apply(InitWeights_He(1e-2))
+ self.print_to_log_file("At epoch 100, the mean foreground Dice was 0. This can be caused by a too "
+ "high momentum. High momentum (0.99) is good for datasets where it works, but "
+ "sometimes causes issues such as this one. Momentum has now been reduced to "
+ "0.95 and network weights have been reinitialized")
+ return continue_training
+
+ def run_training(self):
+ """
+ if we run with -c then we need to set the correct lr for the first epoch, otherwise it will run the first
+ continued epoch with self.initial_lr
+
+ we also need to make sure deep supervision in the network is enabled for training, thus the wrapper
+ :return:
+ """
+        self.maybe_update_lr(self.epoch)  # if we don't overwrite epoch then self.epoch+1 is used which is not what we
+ # want at the start of the training
+ ds = self.network.do_ds
+ self.network.do_ds = True
+ ret = super().run_training()
+ self.network.do_ds = ds
+ return ret
diff --git a/nnunet/training/network_training/nnUNetTrainerV2_100epochs.py b/nnunet/training/network_training/nnUNetTrainerV2_100epochs.py
new file mode 100644
index 0000000000000000000000000000000000000000..e16f60ac09b33224380c11424fd099102420fd1d
--- /dev/null
+++ b/nnunet/training/network_training/nnUNetTrainerV2_100epochs.py
@@ -0,0 +1,127 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from collections import OrderedDict
+from typing import Tuple
+from nnunet.training.loss_functions.dice_loss import DC_and_Focal_loss
+from nnunet.training.loss_functions.crossentropy import RobustCrossEntropyLoss
+from nnunet.training.loss_functions.focal_loss import FocalLossV2
+import numpy as np
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.training.network_training.nnUNetTrainerV2_CascadeFullRes import nnUNetTrainerV2CascadeFullRes
+import torch
+from nnunet.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation
+from nnunet.training.loss_functions.deep_supervision import MultipleOutputLoss2
+from nnunet.utilities.to_torch import maybe_to_torch, to_cuda
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.training.data_augmentation.default_data_augmentation import default_2D_augmentation_params, \
+ get_patch_size, default_3D_augmentation_params
+from nnunet.training.dataloading.dataset_loading import unpack_dataset
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+from nnunet.utilities.nd_softmax import softmax_helper
+from sklearn.model_selection import KFold
+from torch import nn
+from torch.cuda.amp import autocast
+from nnunet.training.learning_rate.poly_lr import poly_lr
+from batchgenerators.utilities.file_and_folder_operations import *
+
+
+class nnUNetTrainerV2_100epochs(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.max_num_epochs = 100
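+# Usage sketch (assuming the standard nnUNet_train entry point and a hypothetical task name): these
+# variants are selected by class name on the command line, e.g.
+#   nnUNet_train 3d_fullres nnUNetTrainerV2_100epochs Task100_Example 0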
+
+class nnUNetTrainerV2_150epochs(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.max_num_epochs = 150
+
+class nnUNetTrainerV2_100epochs_CEnoDS(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.max_num_epochs = 100
+ self.loss = RobustCrossEntropyLoss()
+
+class nnUNetTrainerV2CascadeFullRes_100epochs(nnUNetTrainerV2CascadeFullRes):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, previous_trainer="nnUNetTrainerV2_100epochs", fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory,
+ batch_dice, stage, unpack_data, deterministic, previous_trainer, fp16)
+ self.max_num_epochs = 100
+
+class nnUNetTrainerV2CascadeFullRes_150epochs(nnUNetTrainerV2CascadeFullRes):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, previous_trainer="nnUNetTrainerV2_150epochs", fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory,
+ batch_dice, stage, unpack_data, deterministic, previous_trainer, fp16)
+ self.max_num_epochs = 150
+
+class nnUNetTrainerV2CascadeFullRes_100epochs_CEnoDS(nnUNetTrainerV2CascadeFullRes):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.max_num_epochs = 100
+ self.loss = RobustCrossEntropyLoss()
+
+
+class nnUNetTrainerV2_100epochs_FocalLoss(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ print(
+ "Focal loss parameters: {'alpha':0.25, 'gamma':2, 'smooth':1e-5}")
+ self.max_num_epochs = 100
+ self.loss = FocalLossV2(apply_nonlin=nn.Softmax(
+ dim=1), **{'alpha': 0.25, 'gamma': 2, 'smooth': 1e-5})
+
+
+class nnUNetTrainerV2_100epochs_DCFocalLoss(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ print(
+ "Focal loss parameters: {'alpha':0.25, 'gamma':2, 'smooth':1e-5}")
+ self.max_num_epochs = 100
+ self.loss = DC_and_Focal_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5,
+ 'do_bg': False}, {'alpha': 0.25, 'gamma': 2, 'smooth': 1e-5})
+
+class nnUNetTrainerV2_150epochs_FocalLoss(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ print("Focal loss parameters: {'alpha':0.25, 'gamma':2, 'smooth':1e-5}")
+ self.max_num_epochs = 150
+        self.loss = FocalLossV2(apply_nonlin=nn.Softmax(dim=1), **{'alpha': 0.25, 'gamma': 2, 'smooth': 1e-5})
+
+class nnUNetTrainerV2_150epochs_DCFocalLoss(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ print("Focal loss parameters: {'alpha':0.25, 'gamma':2, 'smooth':1e-5}")
+ self.max_num_epochs = 150
+        self.loss = DC_and_Focal_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False},
+                                      {'alpha': 0.25, 'gamma': 2, 'smooth': 1e-5})
diff --git a/nnunet/training/network_training/nnUNetTrainerV2_CascadeFullRes.py b/nnunet/training/network_training/nnUNetTrainerV2_CascadeFullRes.py
new file mode 100644
index 0000000000000000000000000000000000000000..b855f3752edbe16196d2d695142afdc21a4f3afb
--- /dev/null
+++ b/nnunet/training/network_training/nnUNetTrainerV2_CascadeFullRes.py
@@ -0,0 +1,353 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from multiprocessing.pool import Pool
+from time import sleep
+import matplotlib
+from nnunet.configuration import default_num_threads
+from nnunet.postprocessing.connected_components import determine_postprocessing
+from nnunet.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation
+from nnunet.training.dataloading.dataset_loading import DataLoader3D, unpack_dataset
+from nnunet.evaluation.evaluator import aggregate_scores
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.paths import network_training_output_dir
+from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax
+from batchgenerators.utilities.file_and_folder_operations import *
+import numpy as np
+from nnunet.training.loss_functions.deep_supervision import MultipleOutputLoss2
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.one_hot_encoding import to_one_hot
+import shutil
+
+from torch import nn
+
+matplotlib.use("agg")
+
+
+class nnUNetTrainerV2CascadeFullRes(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, previous_trainer="nnUNetTrainerV2", fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory,
+ batch_dice, stage, unpack_data, deterministic, fp16)
+ self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, previous_trainer, fp16)
+
+ if self.output_folder is not None:
+ task = self.output_folder.split("/")[-3]
+ plans_identifier = self.output_folder.split("/")[-2].split("__")[-1]
+
+ folder_with_segs_prev_stage = join(network_training_output_dir, "3d_lowres",
+ task, previous_trainer + "__" + plans_identifier, "pred_next_stage")
+ self.folder_with_segs_from_prev_stage = folder_with_segs_prev_stage
+ # Do not put segs_prev_stage into self.output_folder as we need to unpack them for performance and we
+ # don't want to do that in self.output_folder because that one is located on some network drive.
+ else:
+ self.folder_with_segs_from_prev_stage = None
+
+ def do_split(self):
+ super().do_split()
+ for k in self.dataset:
+ self.dataset[k]['seg_from_prev_stage_file'] = join(self.folder_with_segs_from_prev_stage,
+ k + "_segFromPrevStage.npz")
+ assert isfile(self.dataset[k]['seg_from_prev_stage_file']), \
+ "seg from prev stage missing: %s. " \
+ "Please run all 5 folds of the 3d_lowres configuration of this " \
+ "task!" % (self.dataset[k]['seg_from_prev_stage_file'])
+ for k in self.dataset_val:
+ self.dataset_val[k]['seg_from_prev_stage_file'] = join(self.folder_with_segs_from_prev_stage,
+ k + "_segFromPrevStage.npz")
+ for k in self.dataset_tr:
+ self.dataset_tr[k]['seg_from_prev_stage_file'] = join(self.folder_with_segs_from_prev_stage,
+ k + "_segFromPrevStage.npz")
+
+ def get_basic_generators(self):
+ self.load_dataset()
+ self.do_split()
+
+ if self.threeD:
+ dl_tr = DataLoader3D(self.dataset_tr, self.basic_generator_patch_size, self.patch_size, self.batch_size,
+ True, oversample_foreground_percent=self.oversample_foreground_percent,
+ pad_mode="constant", pad_sides=self.pad_all_sides)
+ dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, True,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ pad_mode="constant", pad_sides=self.pad_all_sides)
+ else:
+ raise NotImplementedError("2D has no cascade")
+
+ return dl_tr, dl_val
+
+ def process_plans(self, plans):
+ super().process_plans(plans)
+ self.num_input_channels += (self.num_classes - 1) # for seg from prev stage
+
+ def setup_DA_params(self):
+ super().setup_DA_params()
+
+ self.data_aug_params["num_cached_per_thread"] = 2
+
+ self.data_aug_params['move_last_seg_chanel_to_data'] = True
+ self.data_aug_params['cascade_do_cascade_augmentations'] = True
+
+ self.data_aug_params['cascade_random_binary_transform_p'] = 0.4
+ self.data_aug_params['cascade_random_binary_transform_p_per_label'] = 1
+ self.data_aug_params['cascade_random_binary_transform_size'] = (1, 8)
+
+ self.data_aug_params['cascade_remove_conn_comp_p'] = 0.2
+ self.data_aug_params['cascade_remove_conn_comp_max_size_percent_threshold'] = 0.15
+ self.data_aug_params['cascade_remove_conn_comp_fill_with_other_class_p'] = 0.0
+
+ # we have 2 channels now because the segmentation from the previous stage is stored in 'seg' as well until it
+ # is moved to 'data' at the end
+ self.data_aug_params['selected_seg_channels'] = [0, 1]
+ # needed for converting the segmentation from the previous stage to one hot
+ self.data_aug_params['all_segmentation_labels'] = list(range(1, self.num_classes))
+
+ def initialize(self, training=True, force_load_plans=False):
+ """
+        For prediction of test cases just set training=False; this prevents loading of training data and
+        initializing the training batchgenerators
+ :param training:
+ :return:
+ """
+ if not self.was_initialized:
+ if force_load_plans or (self.plans is None):
+ self.load_plans_file()
+
+ self.process_plans(self.plans)
+
+ self.setup_DA_params()
+
+ ################# Here we wrap the loss for deep supervision ############
+ # we need to know the number of outputs of the network
+ net_numpool = len(self.net_num_pool_op_kernel_sizes)
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
+
+            # we don't use the lowest resolution output. Normalize weights so that they sum to 1
+ mask = np.array([True if i < net_numpool - 1 else False for i in range(net_numpool)])
+ weights[~mask] = 0
+ weights = weights / weights.sum()
+ self.ds_loss_weights = weights
+ # now wrap the loss
+ self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
+ ################# END ###################
+
+ self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
+ "_stage%d" % self.stage)
+
+ if training:
+ if not isdir(self.folder_with_segs_from_prev_stage):
+ raise RuntimeError(
+ "Cannot run final stage of cascade. Run corresponding 3d_lowres first and predict the "
+ "segmentations for the next stage")
+
+ self.dl_tr, self.dl_val = self.get_basic_generators()
+ if self.unpack_data:
+ print("unpacking dataset")
+ unpack_dataset(self.folder_with_preprocessed_data)
+ print("done")
+ else:
+ print(
+ "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
+ "will wait all winter for your model to finish!")
+
+ self.tr_gen, self.val_gen = get_moreDA_augmentation(self.dl_tr, self.dl_val,
+ self.data_aug_params[
+ 'patch_size_for_spatialtransform'],
+ self.data_aug_params,
+ deep_supervision_scales=self.deep_supervision_scales,
+ pin_memory=self.pin_memory)
+ self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
+ also_print_to_console=False)
+ self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
+ also_print_to_console=False)
+ else:
+ pass
+
+ self.initialize_network()
+ self.initialize_optimizer_and_scheduler()
+
+ assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
+ else:
+ self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
+
+ self.was_initialized = True
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
+ save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
+ assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"
+
+ current_mode = self.network.training
+ self.network.eval()
+ # save whether network is in deep supervision mode or not
+ ds = self.network.do_ds
+ # disable deep supervision
+ self.network.do_ds = False
+
+ if segmentation_export_kwargs is None:
+ if 'segmentation_export_params' in self.plans.keys():
+ force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
+ interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
+ interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
+ else:
+ force_separate_z = None
+ interpolation_order = 1
+ interpolation_order_z = 0
+ else:
+ force_separate_z = segmentation_export_kwargs['force_separate_z']
+ interpolation_order = segmentation_export_kwargs['interpolation_order']
+ interpolation_order_z = segmentation_export_kwargs['interpolation_order_z']
+
+ if self.dataset_val is None:
+ self.load_dataset()
+ self.do_split()
+
+ output_folder = join(self.output_folder, validation_folder_name)
+ maybe_mkdir_p(output_folder)
+ # this is for debug purposes
+ my_input_args = {'do_mirroring': do_mirroring,
+ 'use_sliding_window': use_sliding_window,
+ 'step': step_size,
+ 'save_softmax': save_softmax,
+ 'use_gaussian': use_gaussian,
+ 'overwrite': overwrite,
+ 'validation_folder_name': validation_folder_name,
+ 'debug': debug,
+ 'all_in_gpu': all_in_gpu,
+ 'segmentation_export_kwargs': segmentation_export_kwargs,
+ }
+ save_json(my_input_args, join(output_folder, "validation_args.json"))
+
+ if do_mirroring:
+ if not self.data_aug_params['do_mirror']:
+ raise RuntimeError("We did not train with mirroring so you cannot do inference with mirroring enabled")
+ mirror_axes = self.data_aug_params['mirror_axes']
+ else:
+ mirror_axes = ()
+
+ pred_gt_tuples = []
+
+ export_pool = Pool(default_num_threads)
+ results = []
+
+ for k in self.dataset_val.keys():
+ properties = load_pickle(self.dataset[k]['properties_file'])
+ fname = properties['list_of_data_files'][0].split("/")[-1][:-12]
+
+ if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \
+ (save_softmax and not isfile(join(output_folder, fname + ".npz"))):
+ data = np.load(self.dataset[k]['data_file'])['data']
+
+ # concat segmentation of previous step
+ seg_from_prev_stage = np.load(join(self.folder_with_segs_from_prev_stage,
+ k + "_segFromPrevStage.npz"))['data'][None]
+
+ print(k, data.shape)
+ data[-1][data[-1] == -1] = 0
+
+ data_for_net = np.concatenate((data[:-1], to_one_hot(seg_from_prev_stage[0], range(1, self.num_classes))))
+
+ softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data_for_net,
+ do_mirroring=do_mirroring,
+ mirror_axes=mirror_axes,
+ use_sliding_window=use_sliding_window,
+ step_size=step_size,
+ use_gaussian=use_gaussian,
+ all_in_gpu=all_in_gpu,
+ mixed_precision=self.fp16)[1]
+
+ softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in self.transpose_backward])
+
+ if save_softmax:
+ softmax_fname = join(output_folder, fname + ".npz")
+ else:
+ softmax_fname = None
+
+ """There is a problem with python process communication that prevents us from communicating objects
+ larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
+ communicated by the multiprocessing.Pipe object then the placeholder (I think) does not allow for long
+ enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
+ patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
+ then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
+ filename or np.ndarray and will handle this automatically"""
+                if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85):  # *0.85 just to be safe
+ np.save(join(output_folder, fname + ".npy"), softmax_pred)
+ softmax_pred = join(output_folder, fname + ".npy")
+
+ results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax,
+ ((softmax_pred, join(output_folder, fname + ".nii.gz"),
+ properties, interpolation_order, None, None, None,
+ softmax_fname, None, force_separate_z,
+ interpolation_order_z),
+ )
+ )
+ )
+
+ pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"),
+ join(self.gt_niftis_folder, fname + ".nii.gz")])
+
+ _ = [i.get() for i in results]
+ self.print_to_log_file("finished prediction")
+
+ # evaluate raw predictions
+ self.print_to_log_file("evaluation of raw predictions")
+ task = self.dataset_directory.split("/")[-1]
+ job_name = self.experiment_name
+ _ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)),
+ json_output_file=join(output_folder, "summary.json"),
+ json_name=job_name + " val tiled %s" % (str(use_sliding_window)),
+ json_author="Fabian",
+ json_task=task, num_threads=default_num_threads)
+
+ if run_postprocessing_on_folds:
+ # in the old nnunet we would stop here. Now we add a postprocessing. This postprocessing can remove everything
+ # except the largest connected component for each class. To see if this improves results, we do this for all
+ # classes and then rerun the evaluation. Those classes for which this resulted in an improved dice score will
+ # have this applied during inference as well
+ self.print_to_log_file("determining postprocessing")
+ determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name,
+ final_subf_name=validation_folder_name + "_postprocessed", debug=debug)
+            # after this the final predictions for the validation set can be found in validation_folder_name_base + "_postprocessed"
+            # They are always in that folder, even if no postprocessing was applied!
+
+        # determining postprocessing on a per-fold basis may be OK for this fold but what if another fold finds another
+        # postprocessing to be better? In this case we need to consolidate. At the time the consolidation is going to be
+ # done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a separate folder to
+ # be used later
+ gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
+ maybe_mkdir_p(gt_nifti_folder)
+ for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
+ success = False
+ attempts = 0
+ e = None
+ while not success and attempts < 10:
+ try:
+ shutil.copy(f, gt_nifti_folder)
+ success = True
+                except OSError as err:
+                    # "except OSError as e" would unbind e at the end of the clause, which would break the
+                    # "raise e" below, so we keep the exception alive under a separate name
+                    e = err
+ attempts += 1
+ sleep(1)
+ if not success:
+ print("Could not copy gt nifti file %s into folder %s" % (f, gt_nifti_folder))
+ if e is not None:
+ raise e
+
+ # restore network deep supervision mode
+ self.network.train(current_mode)
+ self.network.do_ds = ds
diff --git a/nnunet/training/network_training/nnUNetTrainerV2_DDP.py b/nnunet/training/network_training/nnUNetTrainerV2_DDP.py
new file mode 100644
index 0000000000000000000000000000000000000000..64d883fceb1f7c081c5e4eb001265baa992940d2
--- /dev/null
+++ b/nnunet/training/network_training/nnUNetTrainerV2_DDP.py
@@ -0,0 +1,686 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import os
+import shutil
+from _warnings import warn
+from collections import OrderedDict
+from multiprocessing import Pool
+from time import sleep, time
+from typing import Tuple
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from batchgenerators.utilities.file_and_folder_operations import maybe_mkdir_p, join, subfiles, isfile, load_pickle, \
+ save_json
+from nnunet.configuration import default_num_threads
+from nnunet.evaluation.evaluator import aggregate_scores
+from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.postprocessing.connected_components import determine_postprocessing
+from nnunet.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation
+from nnunet.training.dataloading.dataset_loading import unpack_dataset
+from nnunet.training.loss_functions.crossentropy import RobustCrossEntropyLoss
+from nnunet.training.loss_functions.dice_loss import get_tp_fp_fn_tn
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.distributed import awesome_allgather_function
+from nnunet.utilities.nd_softmax import softmax_helper
+from nnunet.utilities.tensor_utilities import sum_tensor
+from nnunet.utilities.to_torch import to_cuda, maybe_to_torch
+from torch import nn, distributed
+from torch.backends import cudnn
+from torch.cuda.amp import autocast
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim.lr_scheduler import _LRScheduler
+from tqdm import trange
+
+
+class nnUNetTrainerV2_DDP(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, local_rank, output_folder=None, dataset_directory=None, batch_dice=True,
+ stage=None,
+ unpack_data=True, deterministic=True, distribute_batch_size=False, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage,
+ unpack_data, deterministic, fp16)
+ self.init_args = (
+ plans_file, fold, local_rank, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, distribute_batch_size, fp16)
+ self.distribute_batch_size = distribute_batch_size
+ np.random.seed(local_rank)
+ torch.manual_seed(local_rank)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(local_rank)
+ self.local_rank = local_rank
+
+ if torch.cuda.is_available():
+ torch.cuda.set_device(local_rank)
+ dist.init_process_group(backend='nccl', init_method='env://')
+
+ self.loss = None
+ self.ce_loss = RobustCrossEntropyLoss()
+
+ self.global_batch_size = None # we need to know this to properly steer oversample
+
+ def set_batch_size_and_oversample(self):
+ batch_sizes = []
+ oversample_percents = []
+
+ world_size = dist.get_world_size()
+ my_rank = dist.get_rank()
+
+ if self.distribute_batch_size:
+ self.global_batch_size = self.batch_size
+ else:
+ self.global_batch_size = self.batch_size * world_size
+
+ batch_size_per_GPU = np.ceil(self.batch_size / world_size).astype(int)
+
+ for rank in range(world_size):
+ if self.distribute_batch_size:
+ if (rank + 1) * batch_size_per_GPU > self.batch_size:
+ batch_size = batch_size_per_GPU - ((rank + 1) * batch_size_per_GPU - self.batch_size)
+ else:
+ batch_size = batch_size_per_GPU
+ else:
+ batch_size = self.batch_size
+
+ batch_sizes.append(batch_size)
+
+ sample_id_low = 0 if len(batch_sizes) == 0 else np.sum(batch_sizes[:-1])
+ sample_id_high = np.sum(batch_sizes)
+
+ if sample_id_high / self.global_batch_size < (1 - self.oversample_foreground_percent):
+ oversample_percents.append(0.0)
+ elif sample_id_low / self.global_batch_size > (1 - self.oversample_foreground_percent):
+ oversample_percents.append(1.0)
+ else:
+ percent_covered_by_this_rank = sample_id_high / self.global_batch_size - sample_id_low / self.global_batch_size
+ oversample_percent_here = 1 - (((1 - self.oversample_foreground_percent) -
+ sample_id_low / self.global_batch_size) / percent_covered_by_this_rank)
+ oversample_percents.append(oversample_percent_here)
+
+ print("worker", my_rank, "oversample", oversample_percents[my_rank])
+ print("worker", my_rank, "batch_size", batch_sizes[my_rank])
+
+ self.batch_size = batch_sizes[my_rank]
+ self.oversample_foreground_percent = oversample_percents[my_rank]
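+        # worked example: with oversample_foreground_percent = 0.33 and a global batch of 4 split across
+        # 2 GPUs (2 samples each), the last ~33% of the global batch must be foreground-oversampled; the
+        # loop above then assigns oversample 0.0 to rank 0 and ~0.66 to rank 1, keeping the global
+        # fraction at ~0.33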
+
+ def save_checkpoint(self, fname, save_optimizer=True):
+ if self.local_rank == 0:
+ super().save_checkpoint(fname, save_optimizer)
+
+ def plot_progress(self):
+ if self.local_rank == 0:
+ super().plot_progress()
+
+ def print_to_log_file(self, *args, also_print_to_console=True):
+ if self.local_rank == 0:
+ super().print_to_log_file(*args, also_print_to_console=also_print_to_console)
+
+ def process_plans(self, plans):
+ super().process_plans(plans)
+ self.set_batch_size_and_oversample()
+
+ def initialize(self, training=True, force_load_plans=False):
+ """
+ :param training:
+ :return:
+ """
+ if not self.was_initialized:
+ maybe_mkdir_p(self.output_folder)
+
+ if force_load_plans or (self.plans is None):
+ self.load_plans_file()
+
+ self.process_plans(self.plans)
+
+ self.setup_DA_params()
+
+ self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
+ "_stage%d" % self.stage)
+ if training:
+ self.dl_tr, self.dl_val = self.get_basic_generators()
+ if self.unpack_data:
+ if self.local_rank == 0:
+ print("unpacking dataset")
+ unpack_dataset(self.folder_with_preprocessed_data)
+ print("done")
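+                    # all ranks synchronize here so no worker starts reading before rank 0 finished unpacking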
+ distributed.barrier()
+ else:
+ print(
+ "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
+ "will wait all winter for your model to finish!")
+
+ # setting weights for deep supervision losses
+ net_numpool = len(self.net_num_pool_op_kernel_sizes)
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
+
+ # we don't use the lowest 2 outputs. Normalize weights so that they sum to 1
+ mask = np.array([True if i < net_numpool - 1 else False for i in range(net_numpool)])
+ weights[~mask] = 0
+ weights = weights / weights.sum()
+ self.ds_loss_weights = weights
+
+                seeds_train = np.random.randint(0, 99999, self.data_aug_params.get('num_threads'))
+                seeds_val = np.random.randint(0, 99999, max(self.data_aug_params.get('num_threads') // 2, 1))
+ print("seeds train", seeds_train)
+ print("seeds_val", seeds_val)
+ self.tr_gen, self.val_gen = get_moreDA_augmentation(self.dl_tr, self.dl_val,
+ self.data_aug_params[
+ 'patch_size_for_spatialtransform'],
+ self.data_aug_params,
+ deep_supervision_scales=self.deep_supervision_scales,
+ seeds_train=seeds_train,
+ seeds_val=seeds_val,
+ pin_memory=self.pin_memory)
+ self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
+ also_print_to_console=False)
+ self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
+ also_print_to_console=False)
+ else:
+ pass
+
+ self.initialize_network()
+ self.initialize_optimizer_and_scheduler()
+ self.network = DDP(self.network, device_ids=[self.local_rank])
+
+ else:
+ self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
+ self.was_initialized = True
+
+ def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
+ data_dict = next(data_generator)
+ data = data_dict['data']
+ target = data_dict['target']
+
+ data = maybe_to_torch(data)
+ target = maybe_to_torch(target)
+
+ if torch.cuda.is_available():
+ data = to_cuda(data, gpu_id=None)
+ target = to_cuda(target, gpu_id=None)
+
+ self.optimizer.zero_grad()
+
+ if self.fp16:
+ with autocast():
+ output = self.network(data)
+ del data
+ l = self.compute_loss(output, target)
+
+ if do_backprop:
+ self.amp_grad_scaler.scale(l).backward()
+ self.amp_grad_scaler.unscale_(self.optimizer)
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.amp_grad_scaler.step(self.optimizer)
+ self.amp_grad_scaler.update()
+ else:
+ output = self.network(data)
+ del data
+ l = self.compute_loss(output, target)
+
+ if do_backprop:
+ l.backward()
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.optimizer.step()
+
+ if run_online_evaluation:
+ self.run_online_evaluation(output, target)
+
+ del target
+
+ return l.detach().cpu().numpy()
+
+ def compute_loss(self, output, target):
+ total_loss = None
+ for i in range(len(output)):
+ # Starting here it gets spicy!
+ axes = tuple(range(2, len(output[i].size())))
+
+ # network does not do softmax. We need to do softmax for dice
+ output_softmax = softmax_helper(output[i])
+
+ # get the tp, fp and fn terms we need
+ tp, fp, fn, _ = get_tp_fp_fn_tn(output_softmax, target[i], axes, mask=None)
+ # for dice, compute nominator and denominator so that we have to accumulate only 2 instead of 3 variables
+ # do_bg=False in nnUNetTrainer -> [:, 1:]
+ nominator = 2 * tp[:, 1:]
+ denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]
+
+ if self.batch_dice:
+ # for DDP we need to gather all nominator and denominator terms from all GPUS to do proper batch dice
+ nominator = awesome_allgather_function.apply(nominator)
+ denominator = awesome_allgather_function.apply(denominator)
+ nominator = nominator.sum(0)
+ denominator = denominator.sum(0)
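+                    # after the allgather, nominator/denominator hold per-class sums over the global batch,
+                    # so the dice below matches single-GPU batch dice semantics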
+ else:
+ pass
+
+ ce_loss = self.ce_loss(output[i], target[i][:, 0].long())
+
+ # we smooth by 1e-5 to penalize false positives if tp is 0
+ dice_loss = (- (nominator + 1e-5) / (denominator + 1e-5)).mean()
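+            # i.e. soft dice per class is 2*TP / (2*TP + FP + FN): a perfect prediction yields
+            # dice_loss = -1, while zero overlap yields ~0 (the 1e-5 smoothing keeps the ratio finite)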
+ if total_loss is None:
+ total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
+ else:
+ total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)
+ return total_loss
+
+ def run_online_evaluation(self, output, target):
+ with torch.no_grad():
+ num_classes = output[0].shape[1]
+ output_seg = output[0].argmax(1)
+ target = target[0][:, 0]
+ axes = tuple(range(1, len(target.shape)))
+ tp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
+ fp_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
+ fn_hard = torch.zeros((target.shape[0], num_classes - 1)).to(output_seg.device.index)
+ for c in range(1, num_classes):
+ tp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target == c).float(), axes=axes)
+ fp_hard[:, c - 1] = sum_tensor((output_seg == c).float() * (target != c).float(), axes=axes)
+ fn_hard[:, c - 1] = sum_tensor((output_seg != c).float() * (target == c).float(), axes=axes)
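+            # tp/fp/fn are "hard" per-class counts on the argmax segmentation, summed over the spatial
+            # axes; e.g. tp_hard[:, c - 1] counts voxels where prediction and target both equal class c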
+
+ # tp_hard, fp_hard, fn_hard = get_tp_fp_fn((output_softmax > (1 / num_classes)).float(), target,
+ # axes, None)
+ # print_if_rank0("before allgather", tp_hard.shape)
+ tp_hard = tp_hard.sum(0, keepdim=False)[None]
+ fp_hard = fp_hard.sum(0, keepdim=False)[None]
+ fn_hard = fn_hard.sum(0, keepdim=False)[None]
+
+ tp_hard = awesome_allgather_function.apply(tp_hard)
+ fp_hard = awesome_allgather_function.apply(fp_hard)
+ fn_hard = awesome_allgather_function.apply(fn_hard)
+
+ tp_hard = tp_hard.detach().cpu().numpy().sum(0)
+ fp_hard = fp_hard.detach().cpu().numpy().sum(0)
+ fn_hard = fn_hard.detach().cpu().numpy().sum(0)
+ self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
+ self.online_eval_tp.append(list(tp_hard))
+ self.online_eval_fp.append(list(fp_hard))
+ self.online_eval_fn.append(list(fn_hard))
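+
+ # Note on the allgather-then-sum order above: the global Dice has to be
+ # computed from summed counts, not by averaging per-GPU Dice values.
+ # Hypothetical two-rank example: rank 0 has tp=9, fp+fn=2 (dice 0.9), rank 1
+ # has tp=1, fp+fn=8 (dice 0.2); summing the counts first gives
+ # 2*10 / (2*10 + 10) ~= 0.67, not the 0.55 a naive mean of dices would report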
+
+ def run_training(self):
+ """
+ if we run with -c then we need to set the correct lr for the first epoch, otherwise it will run the first
+ continued epoch with self.initial_lr
+
+ we also need to make sure deep supervision in the network is enabled for training, thus the wrapper
+ :return:
+ """
+ if self.local_rank == 0:
+ self.save_debug_information()
+
+ if not torch.cuda.is_available():
+ self.print_to_log_file("WARNING!!! You are attempting to run training on a CPU (torch.cuda.is_available() is False). This can be VERY slow!")
+
+ self.maybe_update_lr(self.epoch) # if we don't overwrite epoch then self.epoch+1 is used, which is not what
+ # we want at the start of the training
+ if isinstance(self.network, DDP):
+ net = self.network.module
+ else:
+ net = self.network
+ ds = net.do_ds
+ net.do_ds = True
+
+ _ = self.tr_gen.next()
+ _ = self.val_gen.next()
+
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ self._maybe_init_amp()
+
+ maybe_mkdir_p(self.output_folder)
+ self.plot_network_architecture()
+
+ if cudnn.benchmark and cudnn.deterministic:
+ warn("torch.backends.cudnn.deterministic is True indicating a deterministic training is desired. "
+ "But torch.backends.cudnn.benchmark is True as well and this will prevent deterministic training! "
+ "If you want deterministic then set benchmark=False")
+
+ if not self.was_initialized:
+ self.initialize(True)
+
+ while self.epoch < self.max_num_epochs:
+ self.print_to_log_file("\nepoch: ", self.epoch)
+ epoch_start_time = time()
+ train_losses_epoch = []
+
+ # train one epoch
+ self.network.train()
+
+ if self.use_progress_bar:
+ with trange(self.num_batches_per_epoch) as tbar:
+ for b in tbar:
+ tbar.set_description("Epoch {}/{}".format(self.epoch+1, self.max_num_epochs))
+
+ l = self.run_iteration(self.tr_gen, True)
+
+ tbar.set_postfix(loss=l)
+ train_losses_epoch.append(l)
+ else:
+ for _ in range(self.num_batches_per_epoch):
+ l = self.run_iteration(self.tr_gen, True)
+ train_losses_epoch.append(l)
+
+ self.all_tr_losses.append(np.mean(train_losses_epoch))
+ self.print_to_log_file("train loss : %.4f" % self.all_tr_losses[-1])
+
+ with torch.no_grad():
+ # validation with train=False
+ self.network.eval()
+ val_losses = []
+ for b in range(self.num_val_batches_per_epoch):
+ l = self.run_iteration(self.val_gen, False, True)
+ val_losses.append(l)
+ self.all_val_losses.append(np.mean(val_losses))
+ self.print_to_log_file("validation loss: %.4f" % self.all_val_losses[-1])
+
+ if self.also_val_in_tr_mode:
+ self.network.train()
+ # validation with train=True
+ val_losses = []
+ for b in range(self.num_val_batches_per_epoch):
+ l = self.run_iteration(self.val_gen, False)
+ val_losses.append(l)
+ self.all_val_losses_tr_mode.append(np.mean(val_losses))
+ self.print_to_log_file("validation loss (train=True): %.4f" % self.all_val_losses_tr_mode[-1])
+
+ self.update_train_loss_MA() # needed for lr scheduler and stopping of training
+
+ continue_training = self.on_epoch_end()
+
+ epoch_end_time = time()
+
+ if not continue_training:
+ # allows for early stopping
+ break
+
+ self.epoch += 1
+ self.print_to_log_file("This epoch took %f s\n" % (epoch_end_time - epoch_start_time))
+
+ self.epoch -= 1 # if we don't do this we can get a problem with loading model_final_checkpoint.
+
+ if self.save_final_checkpoint:
+ self.save_checkpoint(join(self.output_folder, "model_final_checkpoint.model"))
+
+ if self.local_rank == 0:
+ # now we can delete latest as it will be identical with final
+ if isfile(join(self.output_folder, "model_latest.model")):
+ os.remove(join(self.output_folder, "model_latest.model"))
+ if isfile(join(self.output_folder, "model_latest.model.pkl")):
+ os.remove(join(self.output_folder, "model_latest.model.pkl"))
+
+ net.do_ds = ds
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
+ step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
+ if isinstance(self.network, DDP):
+ net = self.network.module
+ else:
+ net = self.network
+ ds = net.do_ds
+ net.do_ds = False
+
+ current_mode = self.network.training
+ self.network.eval()
+
+ assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"
+ if self.dataset_val is None:
+ self.load_dataset()
+ self.do_split()
+
+ if segmentation_export_kwargs is None:
+ if 'segmentation_export_params' in self.plans.keys():
+ force_separate_z = self.plans['segmentation_export_params']['force_separate_z']
+ interpolation_order = self.plans['segmentation_export_params']['interpolation_order']
+ interpolation_order_z = self.plans['segmentation_export_params']['interpolation_order_z']
+ else:
+ force_separate_z = None
+ interpolation_order = 1
+ interpolation_order_z = 0
+ else:
+ force_separate_z = segmentation_export_kwargs['force_separate_z']
+ interpolation_order = segmentation_export_kwargs['interpolation_order']
+ interpolation_order_z = segmentation_export_kwargs['interpolation_order_z']
+
+ # predictions as they come from the network go here
+ output_folder = join(self.output_folder, validation_folder_name)
+ maybe_mkdir_p(output_folder)
+ # this is for debug purposes
+ my_input_args = {'do_mirroring': do_mirroring,
+ 'use_sliding_window': use_sliding_window,
+ 'step_size': step_size,
+ 'save_softmax': save_softmax,
+ 'use_gaussian': use_gaussian,
+ 'overwrite': overwrite,
+ 'validation_folder_name': validation_folder_name,
+ 'debug': debug,
+ 'all_in_gpu': all_in_gpu,
+ 'segmentation_export_kwargs': segmentation_export_kwargs,
+ }
+ save_json(my_input_args, join(output_folder, "validation_args.json"))
+
+ if do_mirroring:
+ if not self.data_aug_params['do_mirror']:
+ raise RuntimeError(
+ "We did not train with mirroring so you cannot do inference with mirroring enabled")
+ mirror_axes = self.data_aug_params['mirror_axes']
+ else:
+ mirror_axes = ()
+
+ pred_gt_tuples = []
+
+ export_pool = Pool(default_num_threads)
+ results = []
+
+ all_keys = list(self.dataset_val.keys())
+ my_keys = all_keys[self.local_rank::dist.get_world_size()]
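+ # e.g. with 5 cases and 2 ranks (hypothetical numbers), rank 0 predicts
+ # cases 0, 2 and 4 while rank 1 predicts cases 1 and 3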
+ # we cannot simply iterate over my_keys because we need to know pred_gt_tuples and valid_labels of all cases
+ # for evaluation (which is done by local rank 0)
+ for k in all_keys:
+ properties = load_pickle(self.dataset[k]['properties_file'])
+ fname = properties['list_of_data_files'][0].split("/")[-1][:-12]
+ pred_gt_tuples.append([join(output_folder, fname + ".nii.gz"),
+ join(self.gt_niftis_folder, fname + ".nii.gz")])
+ if k in my_keys:
+ if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \
+ (save_softmax and not isfile(join(output_folder, fname + ".npz"))):
+ data = np.load(self.dataset[k]['data_file'])['data']
+
+ print(k, data.shape)
+ data[-1][data[-1] == -1] = 0
+
+ softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data[:-1],
+ do_mirroring=do_mirroring,
+ mirror_axes=mirror_axes,
+ use_sliding_window=use_sliding_window,
+ step_size=step_size,
+ use_gaussian=use_gaussian,
+ all_in_gpu=all_in_gpu,
+ mixed_precision=self.fp16)[1]
+
+ softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in self.transpose_backward])
+
+ if save_softmax:
+ softmax_fname = join(output_folder, fname + ".npz")
+ else:
+ softmax_fname = None
+
+ """There is a problem with python process communication that prevents us from communicating objects
+ larger than 2 GB between processes (basically when the length of the pickle string that will be sent is
+ communicated by the multiprocessing.Pipe object then the placeholder (I think) does not allow for long
+ enough strings (lol). This could be fixed by changing i to l (for long) but that would require manually
+ patching system python code. We circumvent that problem here by saving softmax_pred to a npy file that will
+ then be read (and finally deleted) by the Process. save_segmentation_nifti_from_softmax can take either
+ filename or np.ndarray and will handle this automatically"""
+ if np.prod(softmax_pred.shape) > (2e9 / 4 * 0.85): # *0.85 just to be safe
+ np.save(join(output_folder, fname + ".npy"), softmax_pred)
+ softmax_pred = join(output_folder, fname + ".npy")
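+ # the threshold is ~2e9 bytes / 4 bytes per float32 element with a 0.85 safety
+ # margin; e.g. a hypothetical (3, 512, 512, 512) float32 softmax has ~4.0e8
+ # elements and would stay in memory, while anything larger goes via .npy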
+
+ results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax,
+ ((softmax_pred, join(output_folder, fname + ".nii.gz"),
+ properties, interpolation_order,
+ self.regions_class_order,
+ None, None,
+ softmax_fname, None, force_separate_z,
+ interpolation_order_z),
+ )
+ )
+ )
+
+ _ = [i.get() for i in results]
+ self.print_to_log_file("finished prediction")
+
+ distributed.barrier()
+
+ if self.local_rank == 0:
+ # evaluate raw predictions
+ self.print_to_log_file("evaluation of raw predictions")
+ task = self.dataset_directory.split("/")[-1]
+ job_name = self.experiment_name
+ _ = aggregate_scores(pred_gt_tuples, labels=list(range(self.num_classes)),
+ json_output_file=join(output_folder, "summary.json"),
+ json_name=job_name + " val tiled %s" % (str(use_sliding_window)),
+ json_author="Fabian",
+ json_task=task, num_threads=default_num_threads)
+
+ if run_postprocessing_on_folds:
+ # in the old nnunet we would stop here. Now we add postprocessing. This postprocessing can remove everything
+ # except the largest connected component for each class. To see if this improves results, we do this for all
+ # classes and then rerun the evaluation. Those classes for which this results in an improved dice score will
+ # have it applied during inference as well
+ self.print_to_log_file("determining postprocessing")
+ determine_postprocessing(self.output_folder, self.gt_niftis_folder, validation_folder_name,
+ final_subf_name=validation_folder_name + "_postprocessed", debug=debug)
+ # after this the final predictions for the validation set can be found in validation_folder_name_base + "_postprocessed"
+ # They are always in that folder, even if no postprocessing was applied!
+
+ # determining postprocessing on a per-fold basis may be OK for this fold, but what if another fold finds another
+ # postprocessing to be better? In this case we need to consolidate. At the time the consolidation is going to be
+ # done we won't know what self.gt_niftis_folder was, so now we copy all the niftis into a separate folder to
+ # be used later
+ gt_nifti_folder = join(self.output_folder_base, "gt_niftis")
+ maybe_mkdir_p(gt_nifti_folder)
+ for f in subfiles(self.gt_niftis_folder, suffix=".nii.gz"):
+ success = False
+ attempts = 0
+ e = None
+ while not success and attempts < 10:
+ try:
+ shutil.copy(f, gt_nifti_folder)
+ success = True
+ except OSError as exc:
+ # 'except ... as' unbinds its name after the block in Python 3, so keep a
+ # reference around for the re-raise below
+ e = exc
+ attempts += 1
+ sleep(1)
+ if not success:
+ print("Could not copy gt nifti file %s into folder %s" % (f, gt_nifti_folder))
+ if e is not None:
+ raise e
+
+ self.network.train(current_mode)
+ net.do_ds = ds
+
+ def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
+ mirror_axes: Tuple[int] = None,
+ use_sliding_window: bool = True, step_size: float = 0.5,
+ use_gaussian: bool = True, pad_border_mode: str = 'constant',
+ pad_kwargs: dict = None, all_in_gpu: bool = False,
+ verbose: bool = True, mixed_precision=True) -> Tuple[
+ np.ndarray, np.ndarray]:
+ if pad_border_mode == 'constant' and pad_kwargs is None:
+ pad_kwargs = {'constant_values': 0}
+
+ if do_mirroring and mirror_axes is None:
+ mirror_axes = self.data_aug_params['mirror_axes']
+
+ if do_mirroring:
+ assert self.data_aug_params["do_mirror"], "Cannot do mirroring as test time augmentation when training " \
+ "was done without mirroring"
+
+ valid = (SegmentationNetwork, nn.DataParallel, DDP)
+ assert isinstance(self.network, valid)
+ if isinstance(self.network, DDP):
+ net = self.network.module
+ else:
+ net = self.network
+ ds = net.do_ds
+ net.do_ds = False
+ ret = net.predict_3D(data, do_mirroring=do_mirroring, mirror_axes=mirror_axes,
+ use_sliding_window=use_sliding_window, step_size=step_size,
+ patch_size=self.patch_size, regions_class_order=self.regions_class_order,
+ use_gaussian=use_gaussian, pad_border_mode=pad_border_mode,
+ pad_kwargs=pad_kwargs, all_in_gpu=all_in_gpu, verbose=verbose,
+ mixed_precision=mixed_precision)
+ net.do_ds = ds
+ return ret
+
+ def load_checkpoint_ram(self, checkpoint, train=True):
+ """
+ used when the checkpoint is already in RAM
+ :param checkpoint:
+ :param train:
+ :return:
+ """
+ if not self.was_initialized:
+ self.initialize(train)
+
+ new_state_dict = OrderedDict()
+ curr_state_dict_keys = list(self.network.state_dict().keys())
+ # if state dict comes from nn.DataParallel but we use non-parallel model here then the state dict keys do not
+ # match. Use heuristic to make it match
+ for k, value in checkpoint['state_dict'].items():
+ key = k
+ if key not in curr_state_dict_keys:
+ print("key %s not found in network state dict, stripping 'module.' prefix" % key)
+ key = key[7:]
+ new_state_dict[key] = value
+
+ if self.fp16:
+ self._maybe_init_amp()
+ if 'amp_grad_scaler' in checkpoint.keys():
+ self.amp_grad_scaler.load_state_dict(checkpoint['amp_grad_scaler'])
+
+ self.network.load_state_dict(new_state_dict)
+ self.epoch = checkpoint['epoch']
+ if train:
+ optimizer_state_dict = checkpoint['optimizer_state_dict']
+ if optimizer_state_dict is not None:
+ self.optimizer.load_state_dict(optimizer_state_dict)
+
+ if self.lr_scheduler is not None and hasattr(self.lr_scheduler, 'load_state_dict') and checkpoint[
+ 'lr_scheduler_state_dict'] is not None:
+ self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler_state_dict'])
+
+ if issubclass(self.lr_scheduler.__class__, _LRScheduler):
+ self.lr_scheduler.step(self.epoch)
+
+ self.all_tr_losses, self.all_val_losses, self.all_val_losses_tr_mode, self.all_val_eval_metrics = checkpoint[
+ 'plot_stuff']
+
+ # after the training is done, the epoch is incremented one more time in my old code. This results in
+ # self.epoch = 1001 for old trained models when the epoch is actually 1000. This causes issues because
+ # len(self.all_tr_losses) = 1000 and the plot function will fail. We can easily detect and correct that here
+ if self.epoch != len(self.all_tr_losses):
+ self.print_to_log_file("WARNING in loading checkpoint: self.epoch != len(self.all_tr_losses). This is "
+ "due to an old bug and should only appear when you are loading old models. New "
+ "models should have this fixed! self.epoch is now set to len(self.all_tr_losses)")
+ self.epoch = len(self.all_tr_losses)
+ self.all_tr_losses = self.all_tr_losses[:self.epoch]
+ self.all_val_losses = self.all_val_losses[:self.epoch]
+ self.all_val_losses_tr_mode = self.all_val_losses_tr_mode[:self.epoch]
+ self.all_val_eval_metrics = self.all_val_eval_metrics[:self.epoch]
diff --git a/nnunet/training/network_training/nnUNetTrainerV2_DP.py b/nnunet/training/network_training/nnUNetTrainerV2_DP.py
new file mode 100644
index 0000000000000000000000000000000000000000..0af5c986733b5ea95dca013ce72c49e6f58303c5
--- /dev/null
+++ b/nnunet/training/network_training/nnUNetTrainerV2_DP.py
@@ -0,0 +1,256 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+import torch
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.network_architecture.generic_UNet_DP import Generic_UNet_DP
+from nnunet.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.to_torch import maybe_to_torch, to_cuda
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.training.dataloading.dataset_loading import unpack_dataset
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+from torch.cuda.amp import autocast
+from torch.nn.parallel.data_parallel import DataParallel
+
+
+class nnUNetTrainerV2_DP(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, num_gpus=1, distribute_batch_size=False, fp16=False):
+ super(nnUNetTrainerV2_DP, self).__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage,
+ unpack_data, deterministic, fp16)
+ self.init_args = (plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, num_gpus, distribute_batch_size, fp16)
+ self.num_gpus = num_gpus
+ self.distribute_batch_size = distribute_batch_size
+ self.dice_smooth = 1e-5
+ self.dice_do_BG = False
+ self.loss = None
+ self.loss_weights = None
+
+ def setup_DA_params(self):
+ super(nnUNetTrainerV2_DP, self).setup_DA_params()
+ self.data_aug_params['num_threads'] = 8 * self.num_gpus
+
+ def process_plans(self, plans):
+ super(nnUNetTrainerV2_DP, self).process_plans(plans)
+ if not self.distribute_batch_size:
+ self.batch_size = self.num_gpus * self.plans['plans_per_stage'][self.stage]['batch_size']
+ else:
+ if self.batch_size < self.num_gpus:
+ print("WARNING: self.batch_size < self.num_gpus. Will not be able to use the GPUs well")
+ elif self.batch_size % self.num_gpus != 0:
+ print("WARNING: self.batch_size % self.num_gpus != 0. Will not be able to use the GPUs well")
+
+ def initialize(self, training=True, force_load_plans=False):
+ """
+ - replaced get_default_augmentation with get_moreDA_augmentation
+ - only run this code once
+ - loss function wrapper for deep supervision
+
+ :param training:
+ :param force_load_plans:
+ :return:
+ """
+ if not self.was_initialized:
+ maybe_mkdir_p(self.output_folder)
+
+ if force_load_plans or (self.plans is None):
+ self.load_plans_file()
+
+ self.process_plans(self.plans)
+
+ self.setup_DA_params()
+
+ ################# Here configure the loss for deep supervision ############
+ net_numpool = len(self.net_num_pool_op_kernel_sizes)
+ weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
+ mask = np.array([i < net_numpool - 1 for i in range(net_numpool)])
+ weights[~mask] = 0
+ weights = weights / weights.sum()
+ self.loss_weights = weights
+ ################# END ###################
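+ # worked example: for net_numpool = 5 the raw weights are
+ # [1, 0.5, 0.25, 0.125, 0.0625]; masking out the lowest resolution and
+ # normalizing yields ~[0.533, 0.267, 0.133, 0.067, 0.0]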
+
+ self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
+ "_stage%d" % self.stage)
+ if training:
+ self.dl_tr, self.dl_val = self.get_basic_generators()
+ if self.unpack_data:
+ print("unpacking dataset")
+ unpack_dataset(self.folder_with_preprocessed_data)
+ print("done")
+ else:
+ print(
+ "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
+ "will wait all winter for your model to finish!")
+
+ self.tr_gen, self.val_gen = get_moreDA_augmentation(self.dl_tr, self.dl_val,
+ self.data_aug_params[
+ 'patch_size_for_spatialtransform'],
+ self.data_aug_params,
+ deep_supervision_scales=self.deep_supervision_scales,
+ pin_memory=self.pin_memory)
+ self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
+ also_print_to_console=False)
+ self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
+ also_print_to_console=False)
+ else:
+ pass
+
+ self.initialize_network()
+ self.initialize_optimizer_and_scheduler()
+
+ assert isinstance(self.network, (SegmentationNetwork, DataParallel))
+ else:
+ self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
+ self.was_initialized = True
+
+ def initialize_network(self):
+ """
+ replace Generic_UNet with Generic_UNet_DP for super speeds
+ """
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.InstanceNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.InstanceNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+ self.network = Generic_UNet_DP(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, InitWeights_He(1e-2),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
+
+ def initialize_optimizer_and_scheduler(self):
+ assert self.network is not None, "self.initialize_network must be called first"
+ self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ momentum=0.99, nesterov=True)
+ self.lr_scheduler = None
+
+ def run_training(self):
+ self.maybe_update_lr(self.epoch)
+
+ # amp must be initialized before DP
+
+ ds = self.network.do_ds
+ self.network.do_ds = True
+ self.network = DataParallel(self.network, tuple(range(self.num_gpus)))
+ ret = nnUNetTrainer.run_training(self)
+ self.network = self.network.module
+ self.network.do_ds = ds
+ return ret
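+
+ # minimal sketch of the wrap/train/unwrap pattern above, assuming a generic
+ # `net` with a `do_ds` flag:
+ #
+ # net.do_ds = True # enable deep supervision heads for training
+ # net = DataParallel(net, device_ids=tuple(range(num_gpus)))
+ # ... train ...
+ # net = net.module # unwrap so checkpoints and inference see the bare model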
+
+ def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
+ data_dict = next(data_generator)
+ data = data_dict['data']
+ target = data_dict['target']
+
+ data = maybe_to_torch(data)
+ target = maybe_to_torch(target)
+
+ if torch.cuda.is_available():
+ data = to_cuda(data)
+ target = to_cuda(target)
+
+ self.optimizer.zero_grad()
+
+ if self.fp16:
+ with autocast():
+ ret = self.network(data, target, return_hard_tp_fp_fn=run_online_evaluation)
+ if run_online_evaluation:
+ ces, tps, fps, fns, tp_hard, fp_hard, fn_hard = ret
+ self.run_online_evaluation(tp_hard, fp_hard, fn_hard)
+ else:
+ ces, tps, fps, fns = ret
+ del data, target
+ l = self.compute_loss(ces, tps, fps, fns)
+
+ if do_backprop:
+ self.amp_grad_scaler.scale(l).backward()
+ self.amp_grad_scaler.unscale_(self.optimizer)
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.amp_grad_scaler.step(self.optimizer)
+ self.amp_grad_scaler.update()
+ else:
+ ret = self.network(data, target, return_hard_tp_fp_fn=run_online_evaluation)
+ if run_online_evaluation:
+ ces, tps, fps, fns, tp_hard, fp_hard, fn_hard = ret
+ self.run_online_evaluation(tp_hard, fp_hard, fn_hard)
+ else:
+ ces, tps, fps, fns = ret
+ del data, target
+ l = self.compute_loss(ces, tps, fps, fns)
+
+ if do_backprop:
+ l.backward()
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.optimizer.step()
+
+ return l.detach().cpu().numpy()
+
+ def run_online_evaluation(self, tp_hard, fp_hard, fn_hard):
+ tp_hard = tp_hard.detach().cpu().numpy().mean(0)
+ fp_hard = fp_hard.detach().cpu().numpy().mean(0)
+ fn_hard = fn_hard.detach().cpu().numpy().mean(0)
+ self.online_eval_foreground_dc.append(list((2 * tp_hard) / (2 * tp_hard + fp_hard + fn_hard + 1e-8)))
+ self.online_eval_tp.append(list(tp_hard))
+ self.online_eval_fp.append(list(fp_hard))
+ self.online_eval_fn.append(list(fn_hard))
+
+ def compute_loss(self, ces, tps, fps, fns):
+ # we now need to effectively reimplement the loss
+ loss = None
+ for i in range(len(ces)):
+ if not self.dice_do_BG:
+ tp = tps[i][:, 1:]
+ fp = fps[i][:, 1:]
+ fn = fns[i][:, 1:]
+ else:
+ tp = tps[i]
+ fp = fps[i]
+ fn = fns[i]
+
+ if self.batch_dice:
+ tp = tp.sum(0)
+ fp = fp.sum(0)
+ fn = fn.sum(0)
+
+ numerator = 2 * tp + self.dice_smooth
+ denominator = 2 * tp + fp + fn + self.dice_smooth
+
+ dice_loss = (- numerator / denominator).mean()
+ if loss is None:
+ loss = self.loss_weights[i] * (ces[i].mean() + dice_loss)
+ else:
+ loss += self.loss_weights[i] * (ces[i].mean() + dice_loss)
+ return loss
\ No newline at end of file
diff --git a/nnunet/training/network_training/nnUNetTrainerV2_fp32.py b/nnunet/training/network_training/nnUNetTrainerV2_fp32.py
new file mode 100644
index 0000000000000000000000000000000000000000..58b7c2fbdfc55df3b2c46ee6acee7ab66694f455
--- /dev/null
+++ b/nnunet/training/network_training/nnUNetTrainerV2_fp32.py
@@ -0,0 +1,27 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_fp32(nnUNetTrainerV2):
+ """
+ Info for Fabian: same as internal nnUNetTrainerV2_2
+ """
+
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, False)
diff --git a/nnunet/training/network_training/nnUNet_variants/__init__.py b/nnunet/training/network_training/nnUNet_variants/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b8078b9dddddf22182fec2555d8d118ea72622
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/__init__.py
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *
\ No newline at end of file
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/__init__.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_3ConvPerStage.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_3ConvPerStage.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe1e8d4420351e6541e1bc9efb5d59ec158238dd
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_3ConvPerStage.py
@@ -0,0 +1,46 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+
+
+class nnUNetTrainerV2_3ConvPerStage(nnUNetTrainerV2):
+ def initialize_network(self):
+ self.base_num_features = 24 # otherwise we run out of VRAM
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.InstanceNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.InstanceNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ 3, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_3ConvPerStage_samefilters.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_3ConvPerStage_samefilters.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5065988998e407c9a276ad85ae2e7d1a0ea935c
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_3ConvPerStage_samefilters.py
@@ -0,0 +1,45 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+
+
+class nnUNetTrainerV2_3ConvPerStageSameFilters(nnUNetTrainerV2):
+ def initialize_network(self):
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.InstanceNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.InstanceNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ 3, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_BN.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_BN.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b77ab13446a9330f07b886e03a46f17471c6deb
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_BN.py
@@ -0,0 +1,55 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+
+
+class nnUNetTrainerV2_BN(nnUNetTrainerV2):
+ def initialize_network(self):
+ """
+ changed deep supervision to False
+ :return:
+ """
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.BatchNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.BatchNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
+
+
+nnUNetTrainerV2_BN_copy1 = nnUNetTrainerV2_BN
+nnUNetTrainerV2_BN_copy2 = nnUNetTrainerV2_BN
+nnUNetTrainerV2_BN_copy3 = nnUNetTrainerV2_BN
+nnUNetTrainerV2_BN_copy4 = nnUNetTrainerV2_BN
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_FRN.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_FRN.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8cfa6822334567395725096b795fbc00070e5b0
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_FRN.py
@@ -0,0 +1,54 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.network_architecture.custom_modules.feature_response_normalization import FRN3D
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+from nnunet.network_architecture.custom_modules.helperModules import Identity
+import torch
+
+
+class nnUNetTrainerV2_FRN(nnUNetTrainerV2):
+ def initialize_network(self):
+ """
+ changed deep supervision to False
+ :return:
+ """
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = FRN3D
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ raise NotImplementedError("2D FRN is not implemented")
+ norm_op = nn.BatchNorm2d # unreachable until a 2D FRN implementation exists
+
+ norm_op_kwargs = {'eps': 1e-6}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = Identity
+ net_nonlin_kwargs = {}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_GN.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_GN.py
new file mode 100644
index 0000000000000000000000000000000000000000..27cfe29b59d7357b2fdca0edf0e1dd2bfc871be1
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_GN.py
@@ -0,0 +1,50 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.network_architecture.custom_modules.helperModules import MyGroupNorm
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+
+
+class nnUNetTrainerV2_GN(nnUNetTrainerV2):
+ def initialize_network(self):
+ """
+ changed deep supervision to False
+ :return:
+ """
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = MyGroupNorm
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = MyGroupNorm
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True, 'num_groups': 8}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_GeLU.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_GeLU.py
new file mode 100644
index 0000000000000000000000000000000000000000..16fb7f972d338a6fc2cf75c4930530f65908a03b
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_GeLU.py
@@ -0,0 +1,72 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+
+try:
+ from torch.nn.functional import gelu
+except ImportError:
+ gelu = None
+
+
+class GeLU(nn.Module):
+ def __init__(self):
+ super().__init__()
+ if gelu is None:
+ raise ImportError('You need to have at least torch==1.7.0 to use GeLUs')
+
+ def forward(self, x):
+ return gelu(x)
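+
+ # note: recent PyTorch versions ship an nn.GELU module that could be used
+ # directly; this thin wrapper only exists so the class can be passed as
+ # net_nonlin with empty kwargs, mirroring the other nonlinearity variants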
+
+
+class nnUNetTrainerV2_GeLU(nnUNetTrainerV2):
+ def initialize_network(self):
+ """
+ - momentum 0.99
+ - SGD instead of Adam
+ - self.lr_scheduler = None because we do poly_lr
+ - deep supervision = True
+ - GeLU instead of LReLU
+ - I am sure I forgot something here
+
+ Known issue: forgot to set neg_slope=0 in InitWeights_He; should not make a difference though
+ :return:
+ """
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.InstanceNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.InstanceNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = GeLU
+ net_nonlin_kwargs = {}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_LReLU_slope_2en1.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_LReLU_slope_2en1.py
new file mode 100644
index 0000000000000000000000000000000000000000..23006409ff6847f42c8c37a1d76206e0dfd83fab
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_LReLU_slope_2en1.py
@@ -0,0 +1,45 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+
+
+class nnUNetTrainerV2_LReLU_slope_2en1(nnUNetTrainerV2):
+ def initialize_network(self):
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.InstanceNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.InstanceNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'inplace': True, 'negative_slope': 2e-1}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(0),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_Mish.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_Mish.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1f25df16733f1965692effce496023aa9271539
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_Mish.py
@@ -0,0 +1,47 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+from nnunet.network_architecture.custom_modules.mish import Mish
+
+
+class nnUNetTrainerV2_Mish(nnUNetTrainerV2):
+ def initialize_network(self):
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.InstanceNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.InstanceNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = Mish
+ net_nonlin_kwargs = {}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(0),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
+
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_NoNormalization.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_NoNormalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d82b8e93995d04d88cf77101b3ab6eab677a4a8
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_NoNormalization.py
@@ -0,0 +1,46 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.network_architecture.custom_modules.helperModules import Identity
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+
+
+class nnUNetTrainerV2_NoNormalization(nnUNetTrainerV2):
+ def initialize_network(self):
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = Identity
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = Identity
+
+ norm_op_kwargs = {}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_NoNormalization_lr1en3.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_NoNormalization_lr1en3.py
new file mode 100644
index 0000000000000000000000000000000000000000..83173aee22652d4def3f130557a1714bb5071858
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_NoNormalization_lr1en3.py
@@ -0,0 +1,25 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNet_variants.architectural_variants.nnUNetTrainerV2_NoNormalization import \
+ nnUNetTrainerV2_NoNormalization
+
+
+class nnUNetTrainerV2_NoNormalization_lr1en3(nnUNetTrainerV2_NoNormalization):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.initial_lr = 1e-3
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ReLU.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ReLU.py
new file mode 100644
index 0000000000000000000000000000000000000000..97f290f1c9466b813d0c770055c3125085e0c977
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ReLU.py
@@ -0,0 +1,45 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+
+
+class nnUNetTrainerV2_ReLU(nnUNetTrainerV2):
+ def initialize_network(self):
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.InstanceNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.InstanceNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.ReLU
+ net_nonlin_kwargs = {'inplace': True}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(0),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ReLU_biasInSegOutput.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ReLU_biasInSegOutput.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1a814c8cfef540fa496286eac26ff079696a75c
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ReLU_biasInSegOutput.py
@@ -0,0 +1,46 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+
+
+class nnUNetTrainerV2_ReLU_biasInSegOutput(nnUNetTrainerV2):
+ def initialize_network(self):
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.InstanceNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.InstanceNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.ReLU
+ net_nonlin_kwargs = {'inplace': True}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(0),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True,
+ seg_output_use_bias=True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ReLU_convReLUIN.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ReLU_convReLUIN.py
new file mode 100644
index 0000000000000000000000000000000000000000..a213116c9ae1e0f12bd3f513bac5baeeefb2d6c3
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ReLU_convReLUIN.py
@@ -0,0 +1,46 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from nnunet.network_architecture.generic_UNet import Generic_UNet, ConvDropoutNonlinNorm
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+
+
+class nnUNetTrainerV2_ReLU_convReLUIN(nnUNetTrainerV2):
+ def initialize_network(self):
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.InstanceNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.InstanceNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.ReLU
+ net_nonlin_kwargs = {'inplace': True}
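+        # basic_block=ConvDropoutNonlinNorm swaps normalization and nonlinearity within each conv block:
+        # conv -> dropout -> ReLU -> instance norm instead of the default conv -> dropout -> instance norm -> ReLU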
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(0),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True,
+ basic_block=ConvDropoutNonlinNorm)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ResencUNet.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ResencUNet.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ec323f7e75e7d7a91be4b1f1a08343a2ee065af
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ResencUNet.py
@@ -0,0 +1,99 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Tuple
+
+import numpy as np
+import torch
+from nnunet.network_architecture.generic_modular_residual_UNet import FabiansUNet, get_default_network_config
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.nd_softmax import softmax_helper
+
+
+class nnUNetTrainerV2_ResencUNet(nnUNetTrainerV2):
+ def initialize_network(self):
+ if self.threeD:
+ cfg = get_default_network_config(3, None, norm_type="in")
+
+ else:
+ cfg = get_default_network_config(1, None, norm_type="in")
+
+ stage_plans = self.plans['plans_per_stage'][self.stage]
+ conv_kernel_sizes = stage_plans['conv_kernel_sizes']
+ blocks_per_stage_encoder = stage_plans['num_blocks_encoder']
+ blocks_per_stage_decoder = stage_plans['num_blocks_decoder']
+ pool_op_kernel_sizes = stage_plans['pool_op_kernel_sizes']
+
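+        # as per the FabiansUNet signature, 2 is the feature map multiplier per downsampling, True/False
+        # toggle deep supervision and logit upscaling, and 320 caps the number of feature maps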
+ self.network = FabiansUNet(self.num_input_channels, self.base_num_features, blocks_per_stage_encoder, 2,
+ pool_op_kernel_sizes, conv_kernel_sizes, cfg, self.num_classes,
+ blocks_per_stage_decoder, True, False, 320, InitWeights_He(1e-2))
+
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
+
+ def setup_DA_params(self):
+ """
+        net_num_pool_op_kernel_sizes is different in the residual UNet: its first entry belongs to the
+        stride-1 input stage and is therefore skipped when deriving the deep supervision scales below
+ """
+ super().setup_DA_params()
+ self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(
+ np.vstack(self.net_num_pool_op_kernel_sizes[1:]), axis=0))[:-1]
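+        # illustration with hypothetical plans: for net_num_pool_op_kernel_sizes = [[1, 1, 1], [2, 2, 2], [2, 2, 2]]
+        # the first entry is skipped and this evaluates to [[1, 1, 1], [0.5, 0.5, 0.5]]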
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
+ save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ force_separate_z: bool = None, interpolation_order: int = 3, interpolation_order_z=0,
+ segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
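+        # force_separate_z, interpolation_order and interpolation_order_z are accepted for interface
+        # compatibility but are not forwarded to nnUNetTrainer.validate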
+ ds = self.network.decoder.deep_supervision
+ self.network.decoder.deep_supervision = False
+ ret = nnUNetTrainer.validate(self, do_mirroring=do_mirroring, use_sliding_window=use_sliding_window,
+ step_size=step_size, save_softmax=save_softmax, use_gaussian=use_gaussian,
+ overwrite=overwrite, validation_folder_name=validation_folder_name,
+ debug=debug, all_in_gpu=all_in_gpu,
+ segmentation_export_kwargs=segmentation_export_kwargs,
+ run_postprocessing_on_folds=run_postprocessing_on_folds)
+ self.network.decoder.deep_supervision = ds
+ return ret
+
+ def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
+ mirror_axes: Tuple[int] = None,
+ use_sliding_window: bool = True, step_size: float = 0.5,
+ use_gaussian: bool = True, pad_border_mode: str = 'constant',
+ pad_kwargs: dict = None, all_in_gpu: bool = False,
+ verbose: bool = True, mixed_precision=True) -> Tuple[np.ndarray, np.ndarray]:
+ ds = self.network.decoder.deep_supervision
+ self.network.decoder.deep_supervision = False
+ ret = nnUNetTrainer.predict_preprocessed_data_return_seg_and_softmax(self, data, do_mirroring=do_mirroring,
+ mirror_axes=mirror_axes,
+ use_sliding_window=use_sliding_window,
+ step_size=step_size,
+ use_gaussian=use_gaussian,
+ pad_border_mode=pad_border_mode,
+ pad_kwargs=pad_kwargs,
+ all_in_gpu=all_in_gpu,
+ verbose=verbose,
+ mixed_precision=mixed_precision)
+ self.network.decoder.deep_supervision = ds
+ return ret
+
+ def run_training(self):
+        self.maybe_update_lr(self.epoch)  # if we don't overwrite the epoch then self.epoch + 1 is used, which is
+        # not what we want at the start of training
+ ds = self.network.decoder.deep_supervision
+ self.network.decoder.deep_supervision = True
+ ret = nnUNetTrainer.run_training(self)
+ self.network.decoder.deep_supervision = ds
+ return ret
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ResencUNet_DA3.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ResencUNet_DA3.py
new file mode 100644
index 0000000000000000000000000000000000000000..11e48b188a948d4a4ef526d88c1f95a7a229617a
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ResencUNet_DA3.py
@@ -0,0 +1,104 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Tuple
+
+import numpy as np
+import torch
+
+from nnunet.network_architecture.generic_modular_residual_UNet import FabiansUNet, get_default_network_config
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+from nnunet.training.network_training.nnUNet_variants.data_augmentation.nnUNetTrainerV2_DA3 import \
+ nnUNetTrainerV2_DA3
+from nnunet.utilities.nd_softmax import softmax_helper
+
+
+class nnUNetTrainerV2_ResencUNet_DA3(nnUNetTrainerV2_DA3):
+ def initialize_network(self):
+ if self.threeD:
+ cfg = get_default_network_config(3, None, norm_type="in")
+
+ else:
+ cfg = get_default_network_config(1, None, norm_type="in")
+
+ stage_plans = self.plans['plans_per_stage'][self.stage]
+ conv_kernel_sizes = stage_plans['conv_kernel_sizes']
+ blocks_per_stage_encoder = stage_plans['num_blocks_encoder']
+ blocks_per_stage_decoder = stage_plans['num_blocks_decoder']
+ pool_op_kernel_sizes = stage_plans['pool_op_kernel_sizes']
+
+ self.network = FabiansUNet(self.num_input_channels, self.base_num_features, blocks_per_stage_encoder, 2,
+ pool_op_kernel_sizes, conv_kernel_sizes, cfg, self.num_classes,
+ blocks_per_stage_decoder, True, False, 320, InitWeights_He(1e-2))
+
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
+
+ def setup_DA_params(self):
+ """
+        net_num_pool_op_kernel_sizes is different in the residual UNet: its first entry belongs to the
+        stride-1 input stage and is therefore skipped when deriving the deep supervision scales below
+ """
+ super().setup_DA_params()
+ self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(
+ np.vstack(self.net_num_pool_op_kernel_sizes[1:]), axis=0))[:-1]
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
+ save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
+ ds = self.network.decoder.deep_supervision
+ self.network.decoder.deep_supervision = False
+
+ ret = nnUNetTrainer.validate(self, do_mirroring=do_mirroring, use_sliding_window=use_sliding_window,
+ step_size=step_size, save_softmax=save_softmax, use_gaussian=use_gaussian,
+ overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug,
+ all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs,
+ run_postprocessing_on_folds=run_postprocessing_on_folds)
+
+ self.network.decoder.deep_supervision = ds
+ return ret
+
+ def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
+ mirror_axes: Tuple[int] = None,
+ use_sliding_window: bool = True, step_size: float = 0.5,
+ use_gaussian: bool = True, pad_border_mode: str = 'constant',
+ pad_kwargs: dict = None, all_in_gpu: bool = False,
+ verbose: bool = True, mixed_precision=True) -> Tuple[np.ndarray, np.ndarray]:
+ ds = self.network.decoder.deep_supervision
+ self.network.decoder.deep_supervision = False
+ ret = nnUNetTrainer.predict_preprocessed_data_return_seg_and_softmax(self, data=data,
+ do_mirroring=do_mirroring,
+ mirror_axes=mirror_axes,
+ use_sliding_window=use_sliding_window,
+ step_size=step_size,
+ use_gaussian=use_gaussian,
+ pad_border_mode=pad_border_mode,
+ pad_kwargs=pad_kwargs,
+ all_in_gpu=all_in_gpu,
+ verbose=verbose,
+ mixed_precision=mixed_precision)
+ self.network.decoder.deep_supervision = ds
+ return ret
+
+ def run_training(self):
+        self.maybe_update_lr(self.epoch)  # if we don't overwrite the epoch then self.epoch + 1 is used, which is
+        # not what we want at the start of training
+ ds = self.network.decoder.deep_supervision
+ self.network.decoder.deep_supervision = True
+ ret = nnUNetTrainer.run_training(self)
+ self.network.decoder.deep_supervision = ds
+ return ret
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ResencUNet_DA3_BN.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ResencUNet_DA3_BN.py
new file mode 100644
index 0000000000000000000000000000000000000000..a02ad367a396c1cb5534c1d1b7151071e61b28fc
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_ResencUNet_DA3_BN.py
@@ -0,0 +1,44 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+from nnunet.network_architecture.generic_modular_residual_UNet import FabiansUNet, get_default_network_config
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNet_variants.architectural_variants.nnUNetTrainerV2_ResencUNet_DA3 import \
+ nnUNetTrainerV2_ResencUNet_DA3
+from nnunet.utilities.nd_softmax import softmax_helper
+
+
+class nnUNetTrainerV2_ResencUNet_DA3_BN(nnUNetTrainerV2_ResencUNet_DA3):
+ def initialize_network(self):
+ if self.threeD:
+ cfg = get_default_network_config(3, None, norm_type="bn")
+
+ else:
+ cfg = get_default_network_config(1, None, norm_type="bn")
+
+ stage_plans = self.plans['plans_per_stage'][self.stage]
+ conv_kernel_sizes = stage_plans['conv_kernel_sizes']
+ blocks_per_stage_encoder = stage_plans['num_blocks_encoder']
+ blocks_per_stage_decoder = stage_plans['num_blocks_decoder']
+ pool_op_kernel_sizes = stage_plans['pool_op_kernel_sizes']
+
+ self.network = FabiansUNet(self.num_input_channels, self.base_num_features, blocks_per_stage_encoder, 2,
+ pool_op_kernel_sizes, conv_kernel_sizes, cfg, self.num_classes,
+ blocks_per_stage_decoder, True, False, 320, InitWeights_He(1e-2))
+
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_allConv3x3.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_allConv3x3.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d9ab2a00fa9236b673e19e35234c44a022f03c5
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_allConv3x3.py
@@ -0,0 +1,60 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+
+
+class nnUNetTrainerV2_allConv3x3(nnUNetTrainerV2):
+ def initialize_network(self):
+ """
+ - momentum 0.99
+ - SGD instead of Adam
+ - self.lr_scheduler = None because we do poly_lr
+ - deep supervision = True
+ - i am sure I forgot something here
+
+ Known issue: forgot to set neg_slope=0 in InitWeights_He; should not make a difference though
+ :return:
+ """
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.InstanceNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.InstanceNorm2d
+
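+        # override the kernel sizes from the plans (which may be anisotropic, e.g. [1, 3, 3]) with 3 everywhere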
+ for s in range(len(self.net_conv_kernel_sizes)):
+ for i in range(len(self.net_conv_kernel_sizes[s])):
+ self.net_conv_kernel_sizes[s][i] = 3
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_lReLU_biasInSegOutput.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_lReLU_biasInSegOutput.py
new file mode 100644
index 0000000000000000000000000000000000000000..85f8237a32138df781716bf839d356a6ceb6aa2e
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_lReLU_biasInSegOutput.py
@@ -0,0 +1,46 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+
+
+class nnUNetTrainerV2_lReLU_biasInSegOutput(nnUNetTrainerV2):
+ def initialize_network(self):
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.InstanceNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.InstanceNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(0),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True,
+ seg_output_use_bias=True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_lReLU_convlReLUIN.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_lReLU_convlReLUIN.py
new file mode 100644
index 0000000000000000000000000000000000000000..198f9ef43eedeaa0db5d673e77407d65738fcb9c
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_lReLU_convlReLUIN.py
@@ -0,0 +1,46 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import torch
+from nnunet.network_architecture.generic_UNet import Generic_UNet, ConvDropoutNonlinNorm
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+
+
+class nnUNetTrainerV2_lReLU_convlReLUIN(nnUNetTrainerV2):
+ def initialize_network(self):
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.InstanceNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.InstanceNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'inplace': True, 'negative_slope': 1e-2}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True,
+ basic_block=ConvDropoutNonlinNorm)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_noDeepSupervision.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_noDeepSupervision.py
new file mode 100644
index 0000000000000000000000000000000000000000..59254abc88a08df29daaaaede99aae9c6694deb9
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_noDeepSupervision.py
@@ -0,0 +1,164 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation
+from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \
+ default_2D_augmentation_params, get_patch_size
+from nnunet.training.dataloading.dataset_loading import unpack_dataset
+from nnunet.training.loss_functions.dice_loss import DC_and_CE_loss
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+import torch
+
+
+class nnUNetTrainerV2_noDeepSupervision(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {})
+
+ def setup_DA_params(self):
+ """
+ we leave out the creation of self.deep_supervision_scales, so it remains None
+ :return:
+ """
+ if self.threeD:
+ self.data_aug_params = default_3D_augmentation_params
+ self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
+ self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
+ self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
+ if self.do_dummy_2D_aug:
+ self.data_aug_params["dummy_2D"] = True
+ self.print_to_log_file("Using dummy2d data augmentation")
+ self.data_aug_params["elastic_deform_alpha"] = \
+ default_2D_augmentation_params["elastic_deform_alpha"]
+ self.data_aug_params["elastic_deform_sigma"] = \
+ default_2D_augmentation_params["elastic_deform_sigma"]
+ self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
+ else:
+ self.do_dummy_2D_aug = False
+ if max(self.patch_size) / min(self.patch_size) > 1.5:
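+                # note that this modifies the module-level default dict in place before the reference is
+                # assigned to self.data_aug_params below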
+ default_2D_augmentation_params['rotation_x'] = (-15. / 360 * 2. * np.pi, 15. / 360 * 2. * np.pi)
+ self.data_aug_params = default_2D_augmentation_params
+ self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm
+
+ if self.do_dummy_2D_aug:
+ self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
+ self.data_aug_params['rotation_x'],
+ self.data_aug_params['rotation_y'],
+ self.data_aug_params['rotation_z'],
+ self.data_aug_params['scale_range'])
+ self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
+ else:
+ self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
+ self.data_aug_params['rotation_y'],
+ self.data_aug_params['rotation_z'],
+ self.data_aug_params['scale_range'])
+
+ self.data_aug_params["scale_range"] = (0.7, 1.4)
+ self.data_aug_params["do_elastic"] = False
+ self.data_aug_params['selected_seg_channels'] = [0]
+ self.data_aug_params['patch_size_for_spatialtransform'] = self.patch_size
+
+ def initialize(self, training=True, force_load_plans=False):
+ """
+ removed deep supervision
+ :return:
+ """
+ if not self.was_initialized:
+ maybe_mkdir_p(self.output_folder)
+
+ if force_load_plans or (self.plans is None):
+ self.load_plans_file()
+
+ self.process_plans(self.plans)
+
+ self.setup_DA_params()
+
+ self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
+ "_stage%d" % self.stage)
+ if training:
+ self.dl_tr, self.dl_val = self.get_basic_generators()
+ if self.unpack_data:
+ print("unpacking dataset")
+ unpack_dataset(self.folder_with_preprocessed_data)
+ print("done")
+ else:
+                print(
+                    "INFO: Not unpacking data! Training may be slow as a result. Pray you are not using 2d or you "
+                    "will wait all winter for your model to finish!")
+
+ assert self.deep_supervision_scales is None
+ self.tr_gen, self.val_gen = get_moreDA_augmentation(self.dl_tr, self.dl_val,
+ self.data_aug_params[
+ 'patch_size_for_spatialtransform'],
+ self.data_aug_params,
+ deep_supervision_scales=self.deep_supervision_scales,
+ classes=None,
+ pin_memory=self.pin_memory)
+
+ self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
+ also_print_to_console=False)
+ self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
+ also_print_to_console=False)
+ else:
+ pass
+
+ self.initialize_network()
+ self.initialize_optimizer_and_scheduler()
+
+ assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
+ else:
+ self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
+ self.was_initialized = True
+
+ def initialize_network(self):
+ """
+ changed deep supervision to False
+ :return:
+ """
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.InstanceNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.InstanceNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op, dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, False, False, lambda x: x, InitWeights_He(1e-2),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
+
+ def run_online_evaluation(self, output, target):
+ return nnUNetTrainer.run_online_evaluation(self, output, target)
diff --git a/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_softDeepSupervision.py b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_softDeepSupervision.py
new file mode 100644
index 0000000000000000000000000000000000000000..62f974a0e82b0313656fa9d56c8c0f9329c4cb0d
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/architectural_variants/nnUNetTrainerV2_softDeepSupervision.py
@@ -0,0 +1,127 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p
+from nnunet.training.data_augmentation.data_augmentation_moreDA import get_moreDA_augmentation
+
+try:
+ from meddec.model_training.ablation_studies.new_nnUNet_candidates.nnUNetTrainerCandidate23_softDeepSupervision4 import \
+ MyDSLoss4
+except ImportError:
+ MyDSLoss4 = None
+
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.training.dataloading.dataset_loading import unpack_dataset
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from torch import nn
+import numpy as np
+
+
+class nnUNetTrainerV2_softDeepSupervision(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.loss = None # we take care of that later
+
+ def initialize(self, training=True, force_load_plans=False):
+ """
+ - replaced get_default_augmentation with get_moreDA_augmentation
+ - only run this code once
+ - loss function wrapper for deep supervision
+
+ :param training:
+ :param force_load_plans:
+ :return:
+ """
+ if not self.was_initialized:
+ maybe_mkdir_p(self.output_folder)
+
+ if force_load_plans or (self.plans is None):
+ self.load_plans_file()
+
+ self.process_plans(self.plans)
+
+ self.setup_DA_params()
+
+ ################# Here we wrap the loss for deep supervision ############
+ # we need to know the number of outputs of the network
+ net_numpool = len(self.net_num_pool_op_kernel_sizes)
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
+
+            # we don't use the lowest resolution output. Normalize weights so that they sum to 1
+            mask = np.array([i < net_numpool - 1 for i in range(net_numpool)])
+ weights[~mask] = 0
+ weights = weights / weights.sum()
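+            # illustration: for net_numpool = 5 the raw weights are [1, 1/2, 1/4, 1/8, 1/16]; after masking the
+            # lowest resolution output and renormalizing they become approximately [0.53, 0.27, 0.13, 0.07, 0]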
+
+ # now wrap the loss
+ if MyDSLoss4 is None:
+                raise RuntimeError("MyDSLoss4 could not be imported from meddec, so this trainer cannot be used")
+
+ self.loss = MyDSLoss4(self.batch_dice, weights)
+ #self.loss = MultipleOutputLoss2(self.loss, weights)
+ ################# END ###################
+
+ self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
+ "_stage%d" % self.stage)
+ if training:
+ self.dl_tr, self.dl_val = self.get_basic_generators()
+ if self.unpack_data:
+ print("unpacking dataset")
+ unpack_dataset(self.folder_with_preprocessed_data)
+ print("done")
+ else:
+                print(
+                    "INFO: Not unpacking data! Training may be slow as a result. Pray you are not using 2d or you "
+                    "will wait all winter for your model to finish!")
+
+ self.tr_gen, self.val_gen = get_moreDA_augmentation(self.dl_tr, self.dl_val,
+ self.data_aug_params[
+ 'patch_size_for_spatialtransform'],
+ self.data_aug_params,
+ deep_supervision_scales=self.deep_supervision_scales,
+ soft_ds=True, classes=[0] + list(self.classes),
+ pin_memory=self.pin_memory)
+ self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
+ also_print_to_console=False)
+ self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
+ also_print_to_console=False)
+ else:
+ pass
+
+ self.initialize_network()
+ self.initialize_optimizer_and_scheduler()
+
+ assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
+ else:
+ self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
+ self.was_initialized = True
+
+ def run_online_evaluation(self, output, target):
+ """
+        due to deep supervision, the output and the reference are now lists of tensors. We only evaluate the full
+        resolution output because that is what we are ultimately interested in; the others are ignored
+ :param output:
+ :param target:
+ :return:
+ """
+        # we need to restore the color channel dimension here to be compatible with previous code
+        target = target[0][:, None]
+ output = output[0]
+ return nnUNetTrainer.run_online_evaluation(self, output, target)
diff --git a/nnunet/training/network_training/nnUNet_variants/benchmarking/__init__.py b/nnunet/training/network_training/nnUNet_variants/benchmarking/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/training/network_training/nnUNet_variants/benchmarking/nnUNetTrainerV2_2epochs.py b/nnunet/training/network_training/nnUNet_variants/benchmarking/nnUNetTrainerV2_2epochs.py
new file mode 100644
index 0000000000000000000000000000000000000000..296ca7f4e6261e542ec060080df2ca5b19ce7a16
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/benchmarking/nnUNetTrainerV2_2epochs.py
@@ -0,0 +1,293 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Tuple
+import numpy as np
+import torch
+
+from nnunet.training.loss_functions.crossentropy import RobustCrossEntropyLoss
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.training.network_training.nnUNetTrainerV2_DDP import nnUNetTrainerV2_DDP
+from nnunet.training.network_training.nnUNet_variants.architectural_variants.nnUNetTrainerV2_noDeepSupervision import \
+ nnUNetTrainerV2_noDeepSupervision
+from nnunet.utilities.to_torch import maybe_to_torch, to_cuda
+from torch.cuda.amp import autocast
+
+
+class nnUNetTrainerV2_2epochs(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.max_num_epochs = 2
+
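+    # benchmarking trainer: validation, prediction and checkpoint saving are deliberate no-ops so that only
+    # the training loop itself contributes to the measured runtime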
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
+ save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs=None, run_postprocessing_on_folds: bool = True):
+ pass
+
+ def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
+ mirror_axes: Tuple[int] = None,
+ use_sliding_window: bool = True, step_size: float = 0.5,
+ use_gaussian: bool = True, pad_border_mode: str = 'constant',
+ pad_kwargs: dict = None, all_in_gpu: bool = False,
+ verbose: bool = True, mixed_precision=True) -> Tuple[np.ndarray, np.ndarray]:
+ pass
+
+ def save_checkpoint(self, fname, save_optimizer=True):
+ pass
+
+
+class nnUNetTrainerV2_5epochs(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.max_num_epochs = 5
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
+ save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs=None, run_postprocessing_on_folds: bool = True):
+ pass
+
+ def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
+ mirror_axes: Tuple[int] = None,
+ use_sliding_window: bool = True, step_size: float = 0.5,
+ use_gaussian: bool = True, pad_border_mode: str = 'constant',
+ pad_kwargs: dict = None, all_in_gpu: bool = False,
+ verbose: bool = True, mixed_precision=True) -> Tuple[np.ndarray, np.ndarray]:
+ pass
+
+ def save_checkpoint(self, fname, save_optimizer=True):
+ pass
+
+
+class nnUNetTrainerV2_5epochs_CEnoDS(nnUNetTrainerV2_noDeepSupervision):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.max_num_epochs = 5
+ self.loss = RobustCrossEntropyLoss()
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
+ save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs=None, run_postprocessing_on_folds: bool = True):
+ pass
+
+ def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
+ mirror_axes: Tuple[int] = None,
+ use_sliding_window: bool = True, step_size: float = 0.5,
+ use_gaussian: bool = True, pad_border_mode: str = 'constant',
+ pad_kwargs: dict = None, all_in_gpu: bool = False,
+ verbose: bool = True, mixed_precision=True) -> Tuple[np.ndarray, np.ndarray]:
+ pass
+
+ def save_checkpoint(self, fname, save_optimizer=True):
+ pass
+
+ def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
+ data_dict = next(data_generator)
+ data = data_dict['data']
+ target = data_dict['target']
+
+ data = maybe_to_torch(data)
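+        # the cross entropy loss expects integer class labels without a channel axis, hence the cast to long
+        # and the removal of the segmentation channel dimension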
+ target = maybe_to_torch(target).long()[:, 0]
+
+ if torch.cuda.is_available():
+ data = to_cuda(data)
+ target = to_cuda(target)
+
+ self.optimizer.zero_grad()
+
+ if self.fp16:
+ with autocast():
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ if do_backprop:
+ self.amp_grad_scaler.scale(l).backward()
+ self.amp_grad_scaler.unscale_(self.optimizer)
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.amp_grad_scaler.step(self.optimizer)
+ self.amp_grad_scaler.update()
+ else:
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ if do_backprop:
+ l.backward()
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.optimizer.step()
+
+ if run_online_evaluation:
+ self.run_online_evaluation(output, target)
+
+ del target
+
+ return l.detach().cpu().numpy()
+
+ def run_online_evaluation(self, output, target):
+ pass
+
+ def finish_online_evaluation(self):
+ pass
+
+
+class nnUNetTrainerV2_5epochs_noDS(nnUNetTrainerV2_noDeepSupervision):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.max_num_epochs = 5
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
+ save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs=None, run_postprocessing_on_folds: bool = True):
+ pass
+
+ def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
+ mirror_axes: Tuple[int] = None,
+ use_sliding_window: bool = True, step_size: float = 0.5,
+ use_gaussian: bool = True, pad_border_mode: str = 'constant',
+ pad_kwargs: dict = None, all_in_gpu: bool = False,
+ verbose: bool = True, mixed_precision=True) -> Tuple[np.ndarray, np.ndarray]:
+ pass
+
+ def save_checkpoint(self, fname, save_optimizer=True):
+ pass
+
+ def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
+ data_dict = next(data_generator)
+ data = data_dict['data']
+ target = data_dict['target']
+
+ data = maybe_to_torch(data)
+ target = maybe_to_torch(target)
+
+ if torch.cuda.is_available():
+ data = to_cuda(data)
+ target = to_cuda(target)
+
+ self.optimizer.zero_grad()
+
+ if self.fp16:
+ with autocast():
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ if do_backprop:
+ self.amp_grad_scaler.scale(l).backward()
+ self.amp_grad_scaler.unscale_(self.optimizer)
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.amp_grad_scaler.step(self.optimizer)
+ self.amp_grad_scaler.update()
+ else:
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ if do_backprop:
+ l.backward()
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.optimizer.step()
+
+ if run_online_evaluation:
+ self.run_online_evaluation(output, target)
+
+ del target
+
+ return l.detach().cpu().numpy()
+
+ def run_online_evaluation(self, output, target):
+ pass
+
+ def finish_online_evaluation(self):
+ pass
+
+
+class nnUNetTrainerV2_DDP_5epochs(nnUNetTrainerV2_DDP):
+ def __init__(self, plans_file, fold, local_rank, output_folder=None, dataset_directory=None, batch_dice=True,
+ stage=None,
+ unpack_data=True, deterministic=True, distribute_batch_size=False, fp16=False):
+ super().__init__(plans_file, fold, local_rank, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, distribute_batch_size, fp16)
+ self.max_num_epochs = 5
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
+ save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs=None, run_postprocessing_on_folds: bool = True):
+ pass
+
+ def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
+ mirror_axes: Tuple[int] = None,
+ use_sliding_window: bool = True, step_size: float = 0.5,
+ use_gaussian: bool = True, pad_border_mode: str = 'constant',
+ pad_kwargs: dict = None, all_in_gpu: bool = False,
+ verbose: bool = True, mixed_precision=True) -> Tuple[np.ndarray, np.ndarray]:
+ pass
+
+ def save_checkpoint(self, fname, save_optimizer=True):
+ pass
+
+
+class nnUNetTrainerV2_DDP_5epochs_dummyLoad(nnUNetTrainerV2_DDP_5epochs):
+ def initialize(self, training=True, force_load_plans=False):
+ super().initialize(training, force_load_plans)
+ self.some_batch = torch.rand((self.batch_size, self.num_input_channels, *self.patch_size)).float().cuda()
+
+ self.some_gt = [torch.round(torch.rand((self.batch_size, 1, *[int(i * j) for i, j in zip(self.patch_size, k)])) * (
+ self.num_classes - 1)).float().cuda() for k in self.deep_supervision_scales]
+
+ def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
+ data = self.some_batch
+ target = self.some_gt
+
+ self.optimizer.zero_grad()
+
+ if self.fp16:
+ with autocast():
+ output = self.network(data)
+ del data
+ l = self.compute_loss(output, target)
+
+ if do_backprop:
+ self.amp_grad_scaler.scale(l).backward()
+ self.amp_grad_scaler.unscale_(self.optimizer)
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.amp_grad_scaler.step(self.optimizer)
+ self.amp_grad_scaler.update()
+ else:
+ output = self.network(data)
+ del data
+ l = self.compute_loss(output, target)
+
+ if do_backprop:
+ l.backward()
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.optimizer.step()
+
+ if run_online_evaluation:
+ self.run_online_evaluation(output, target)
+
+ del target
+
+ return l.detach().cpu().numpy()
\ No newline at end of file
diff --git a/nnunet/training/network_training/nnUNet_variants/benchmarking/nnUNetTrainerV2_dummyLoad.py b/nnunet/training/network_training/nnUNet_variants/benchmarking/nnUNetTrainerV2_dummyLoad.py
new file mode 100644
index 0000000000000000000000000000000000000000..355857a9eabd0f5f2201bf3923e343a3965956d4
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/benchmarking/nnUNetTrainerV2_dummyLoad.py
@@ -0,0 +1,147 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Tuple
+
+import torch
+
+from nnunet.training.loss_functions.crossentropy import RobustCrossEntropyLoss
+from nnunet.training.network_training.nnUNet_variants.architectural_variants.nnUNetTrainerV2_noDeepSupervision import \
+ nnUNetTrainerV2_noDeepSupervision
+from nnunet.training.network_training.nnUNet_variants.benchmarking.nnUNetTrainerV2_2epochs import nnUNetTrainerV2_5epochs
+from torch.cuda.amp import autocast
+import numpy as np
+
+
+class nnUNetTrainerV2_5epochs_dummyLoad(nnUNetTrainerV2_5epochs):
+ def initialize(self, training=True, force_load_plans=False):
+ super().initialize(training, force_load_plans)
+ self.some_batch = torch.rand((self.batch_size, self.num_input_channels, *self.patch_size)).float().cuda()
+
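+        # one random label map per deep supervision output: patch_size scaled by the corresponding
+        # deep_supervision_scales entry, with values rounded to integers in [0, num_classes - 1]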
+ self.some_gt = [torch.round(torch.rand((self.batch_size, 1, *[int(i * j) for i, j in zip(self.patch_size, k)])) * (self.num_classes - 1)).float().cuda() for k in self.deep_supervision_scales]
+
+ def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
+ data = self.some_batch
+ target = self.some_gt
+
+ self.optimizer.zero_grad()
+
+ if self.fp16:
+ with autocast():
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ if do_backprop:
+ self.amp_grad_scaler.scale(l).backward()
+ self.amp_grad_scaler.unscale_(self.optimizer)
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.amp_grad_scaler.step(self.optimizer)
+ self.amp_grad_scaler.update()
+ else:
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ if do_backprop:
+ l.backward()
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.optimizer.step()
+
+ if run_online_evaluation:
+ self.run_online_evaluation(output, target)
+
+ del target
+
+ return l.detach().cpu().numpy()
+
+
+class nnUNetTrainerV2_2epochs_dummyLoad(nnUNetTrainerV2_5epochs_dummyLoad):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.max_num_epochs = 2
+
+
+class nnUNetTrainerV2_5epochs_dummyLoadCEnoDS(nnUNetTrainerV2_noDeepSupervision):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.max_num_epochs = 5
+ self.loss = RobustCrossEntropyLoss()
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True, step_size: float = 0.5,
+ save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs=None, run_postprocessing_on_folds: bool = True):
+ pass
+
+ def predict_preprocessed_data_return_seg_and_softmax(self, data: np.ndarray, do_mirroring: bool = True,
+ mirror_axes: Tuple[int] = None,
+ use_sliding_window: bool = True, step_size: float = 0.5,
+ use_gaussian: bool = True, pad_border_mode: str = 'constant',
+ pad_kwargs: dict = None, all_in_gpu: bool = False,
+ verbose: bool = True, mixed_precision=True) -> Tuple[np.ndarray, np.ndarray]:
+ pass
+
+ def save_checkpoint(self, fname, save_optimizer=True):
+ pass
+
+ def initialize(self, training=True, force_load_plans=False):
+ super().initialize(training, force_load_plans)
+ self.some_batch = torch.rand((self.batch_size, self.num_input_channels, *self.patch_size)).float().cuda()
+
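+        # no deep supervision here: a single random label map without a channel axis, cast to long for the
+        # cross entropy loss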
+ self.some_gt = torch.round(torch.rand((self.batch_size, *self.patch_size)) * (self.num_classes - 1)).long().cuda()
+
+ def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
+ data = self.some_batch
+ target = self.some_gt
+
+ self.optimizer.zero_grad()
+
+ if self.fp16:
+ with autocast():
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ if do_backprop:
+ self.amp_grad_scaler.scale(l).backward()
+ self.amp_grad_scaler.unscale_(self.optimizer)
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.amp_grad_scaler.step(self.optimizer)
+ self.amp_grad_scaler.update()
+ else:
+ output = self.network(data)
+ del data
+ l = self.loss(output, target)
+
+ if do_backprop:
+ l.backward()
+ torch.nn.utils.clip_grad_norm_(self.network.parameters(), 12)
+ self.optimizer.step()
+
+ if run_online_evaluation:
+ self.run_online_evaluation(output, target)
+
+ del target
+
+ return l.detach().cpu().numpy()
+
+ def run_online_evaluation(self, output, target):
+ pass
+
+ def finish_online_evaluation(self):
+ pass
diff --git a/nnunet/training/network_training/nnUNet_variants/cascade/__init__.py b/nnunet/training/network_training/nnUNet_variants/cascade/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/training/network_training/nnUNet_variants/cascade/nnUNetTrainerV2CascadeFullRes_DAVariants.py b/nnunet/training/network_training/nnUNet_variants/cascade/nnUNetTrainerV2CascadeFullRes_DAVariants.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e5275d16d08611484b329a86da211d22715105c
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/cascade/nnUNetTrainerV2CascadeFullRes_DAVariants.py
@@ -0,0 +1,87 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2_CascadeFullRes import nnUNetTrainerV2CascadeFullRes
+
+
+class nnUNetTrainerV2CascadeFullRes_noConnComp(nnUNetTrainerV2CascadeFullRes):
+ def setup_DA_params(self):
+ super().setup_DA_params()
+ self.data_aug_params['cascade_do_cascade_augmentations'] = True
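+        # the cascade augmentations perturb the previous stage's segmentation (provided as extra input
+        # channels): random binary (morphological) transforms are applied with probability
+        # cascade_random_binary_transform_p, with structuring element sizes drawn from
+        # cascade_random_binary_transform_size; connected component removal is disabled here (p = 0.0)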
+
+ self.data_aug_params['cascade_random_binary_transform_p'] = 0.4
+ self.data_aug_params['cascade_random_binary_transform_p_per_label'] = 1
+ self.data_aug_params['cascade_random_binary_transform_size'] = (1, 8)
+
+ self.data_aug_params['cascade_remove_conn_comp_p'] = 0.0
+ self.data_aug_params['cascade_remove_conn_comp_max_size_percent_threshold'] = 0.15
+ self.data_aug_params['cascade_remove_conn_comp_fill_with_other_class_p'] = 0.0
+
+
+class nnUNetTrainerV2CascadeFullRes_smallerBinStrel(nnUNetTrainerV2CascadeFullRes):
+ def setup_DA_params(self):
+ super().setup_DA_params()
+ self.data_aug_params['cascade_do_cascade_augmentations'] = True
+
+ self.data_aug_params['cascade_random_binary_transform_p'] = 0.4
+ self.data_aug_params['cascade_random_binary_transform_p_per_label'] = 1
+ self.data_aug_params['cascade_random_binary_transform_size'] = (1, 5)
+
+ self.data_aug_params['cascade_remove_conn_comp_p'] = 0.2
+ self.data_aug_params['cascade_remove_conn_comp_max_size_percent_threshold'] = 0.15
+ self.data_aug_params['cascade_remove_conn_comp_fill_with_other_class_p'] = 0.0
+
+
+class nnUNetTrainerV2CascadeFullRes_EducatedGuess(nnUNetTrainerV2CascadeFullRes):
+ def setup_DA_params(self):
+ super().setup_DA_params()
+ self.data_aug_params['cascade_do_cascade_augmentations'] = True
+
+ self.data_aug_params['cascade_random_binary_transform_p'] = 0.5
+ self.data_aug_params['cascade_random_binary_transform_p_per_label'] = 0.5
+ self.data_aug_params['cascade_random_binary_transform_size'] = (1, 5)
+
+ self.data_aug_params['cascade_remove_conn_comp_p'] = 0.2
+ self.data_aug_params['cascade_remove_conn_comp_max_size_percent_threshold'] = 0.10
+ self.data_aug_params['cascade_remove_conn_comp_fill_with_other_class_p'] = 0.0
+
+
+class nnUNetTrainerV2CascadeFullRes_EducatedGuess2(nnUNetTrainerV2CascadeFullRes):
+ def setup_DA_params(self):
+ super().setup_DA_params()
+ self.data_aug_params['cascade_do_cascade_augmentations'] = True
+
+ self.data_aug_params['cascade_random_binary_transform_p'] = 0.5
+ self.data_aug_params['cascade_random_binary_transform_p_per_label'] = 0.5
+ self.data_aug_params['cascade_random_binary_transform_size'] = (1, 5)
+
+ self.data_aug_params['cascade_remove_conn_comp_p'] = 0.0
+ self.data_aug_params['cascade_remove_conn_comp_max_size_percent_threshold'] = 0.10
+ self.data_aug_params['cascade_remove_conn_comp_fill_with_other_class_p'] = 0.0
+
+
+class nnUNetTrainerV2CascadeFullRes_EducatedGuess3(nnUNetTrainerV2CascadeFullRes):
+ def setup_DA_params(self):
+ super().setup_DA_params()
+ self.data_aug_params['cascade_do_cascade_augmentations'] = True
+
+ self.data_aug_params['cascade_random_binary_transform_p'] = 1
+ self.data_aug_params['cascade_random_binary_transform_p_per_label'] = 0.33
+ self.data_aug_params['cascade_random_binary_transform_size'] = (1, 5)
+
+ self.data_aug_params['cascade_remove_conn_comp_p'] = 0.0
+ self.data_aug_params['cascade_remove_conn_comp_max_size_percent_threshold'] = 0.10
+ self.data_aug_params['cascade_remove_conn_comp_fill_with_other_class_p'] = 0.0
+
diff --git a/nnunet/training/network_training/nnUNet_variants/cascade/nnUNetTrainerV2CascadeFullRes_lowerLR.py b/nnunet/training/network_training/nnUNet_variants/cascade/nnUNetTrainerV2CascadeFullRes_lowerLR.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ca7ee03f9dfcc9acdeb69abd764c3ab960ab740
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/cascade/nnUNetTrainerV2CascadeFullRes_lowerLR.py
@@ -0,0 +1,25 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2_CascadeFullRes import nnUNetTrainerV2CascadeFullRes
+
+
+class nnUNetTrainerV2CascadeFullRes_lowerLR(nnUNetTrainerV2CascadeFullRes):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, previous_trainer="nnUNetTrainerV2", fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory,
+ batch_dice, stage, unpack_data, deterministic,
+ previous_trainer, fp16)
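+        # lower the initial learning rate (the nnUNetTrainerV2 default is 1e-2)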
+ self.initial_lr = 1e-3
diff --git a/nnunet/training/network_training/nnUNet_variants/cascade/nnUNetTrainerV2CascadeFullRes_shorter.py b/nnunet/training/network_training/nnUNet_variants/cascade/nnUNetTrainerV2CascadeFullRes_shorter.py
new file mode 100644
index 0000000000000000000000000000000000000000..26da3d5a294daecdc5e1040cb1f24286ad11fb1d
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/cascade/nnUNetTrainerV2CascadeFullRes_shorter.py
@@ -0,0 +1,25 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2_CascadeFullRes import nnUNetTrainerV2CascadeFullRes
+
+
+class nnUNetTrainerV2CascadeFullRes_shorter(nnUNetTrainerV2CascadeFullRes):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, previous_trainer="nnUNetTrainerV2", fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory,
+ batch_dice, stage, unpack_data, deterministic,
+ previous_trainer, fp16)
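+        # train for 500 epochs instead of the nnUNetTrainerV2 default of 1000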
+ self.max_num_epochs = 500
diff --git a/nnunet/training/network_training/nnUNet_variants/cascade/nnUNetTrainerV2CascadeFullRes_shorter_lowerLR.py b/nnunet/training/network_training/nnUNet_variants/cascade/nnUNetTrainerV2CascadeFullRes_shorter_lowerLR.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c870735d14a0232a8c723c79f696cbb058c0ee6
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/cascade/nnUNetTrainerV2CascadeFullRes_shorter_lowerLR.py
@@ -0,0 +1,26 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2_CascadeFullRes import nnUNetTrainerV2CascadeFullRes
+
+
+class nnUNetTrainerV2CascadeFullRes_shorter_lowerLR(nnUNetTrainerV2CascadeFullRes):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, previous_trainer="nnUNetTrainerV2", fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory,
+ batch_dice, stage, unpack_data, deterministic,
+ previous_trainer, fp16)
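+        # combine the _shorter (500 epochs) and _lowerLR (1e-3) modifications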
+ self.max_num_epochs = 500
+ self.initial_lr = 1e-3
diff --git a/nnunet/training/network_training/nnUNet_variants/copies/__init__.py b/nnunet/training/network_training/nnUNet_variants/copies/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/training/network_training/nnUNet_variants/copies/nnUNetTrainerV2_copies.py b/nnunet/training/network_training/nnUNet_variants/copies/nnUNetTrainerV2_copies.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ec8ffc1e5146a3fb13e4b723eca2b18a0bdf8af
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/copies/nnUNetTrainerV2_copies.py
@@ -0,0 +1,49 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+# These classes exist only so that we can check the stability of results. Training is nondeterministic, and by
+# renaming the trainer class we can have several trained models coexist even though the trainer is effectively the same.
+
+
+class nnUNetTrainerV2_copy1(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+
+
+class nnUNetTrainerV2_copy2(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+
+
+class nnUNetTrainerV2_copy3(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+
+
+class nnUNetTrainerV2_copy4(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+
diff --git a/nnunet/training/network_training/nnUNet_variants/data_augmentation/__init__.py b/nnunet/training/network_training/nnUNet_variants/data_augmentation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_DA2.py b/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_DA2.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff517ed2647f224049219cdc171a3f20476babf7
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_DA2.py
@@ -0,0 +1,31 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_DA2(nnUNetTrainerV2):
+ def setup_DA_params(self):
+ super().setup_DA_params()
+
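+        # scale each axis independently rather than isotropically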
+ self.data_aug_params["independent_scale_factor_for_each_axis"] = True
+
+ if self.threeD:
+ self.data_aug_params["rotation_p_per_axis"] = 0.5
+ else:
+ self.data_aug_params["rotation_p_per_axis"] = 1
+
+ self.data_aug_params["do_additive_brightness"] = True
+
diff --git a/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_DA3.py b/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_DA3.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf282e5c9a0c4e840d34517d2e95c3821ea7f535
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_DA3.py
@@ -0,0 +1,190 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numpy as np
+import torch
+from batchgenerators.utilities.file_and_folder_operations import join
+from nnunet.network_architecture.generic_UNet import Generic_UNet
+from nnunet.network_architecture.initialization import InitWeights_He
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.training.data_augmentation.data_augmentation_insaneDA2 import get_insaneDA_augmentation2
+from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \
+ default_2D_augmentation_params, get_patch_size
+from nnunet.training.dataloading.dataset_loading import unpack_dataset
+from nnunet.training.loss_functions.deep_supervision import MultipleOutputLoss2
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2, maybe_mkdir_p
+from nnunet.utilities.nd_softmax import softmax_helper
+from torch import nn
+
+
+class nnUNetTrainerV2_DA3(nnUNetTrainerV2):
+ def setup_DA_params(self):
+ super().setup_DA_params()
+ self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(
+ np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1]
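+        # e.g. pooling kernels [[2, 2, 2], [2, 2, 2]] give cumprod [[2, 2, 2], [4, 4, 4]],
+        # so the supervision scales become [[1, 1, 1], [0.5, 0.5, 0.5]] (coarsest level dropped)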
+
+ if self.threeD:
+ self.data_aug_params = default_3D_augmentation_params
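+            # 30. / 360 * 2. * np.pi is just 30 degrees expressed in radians (~0.524)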
+ self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
+ self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
+ self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
+ if self.do_dummy_2D_aug:
+ self.data_aug_params["dummy_2D"] = True
+ self.print_to_log_file("Using dummy2d data augmentation")
+ self.data_aug_params["elastic_deform_alpha"] = \
+ default_2D_augmentation_params["elastic_deform_alpha"]
+ self.data_aug_params["elastic_deform_sigma"] = \
+ default_2D_augmentation_params["elastic_deform_sigma"]
+ self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
+ else:
+ self.do_dummy_2D_aug = False
+ if max(self.patch_size) / min(self.patch_size) > 1.5:
+ default_2D_augmentation_params['rotation_x'] = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
+ self.data_aug_params = default_2D_augmentation_params
+ self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm
+
+ if self.do_dummy_2D_aug:
+ self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
+ self.data_aug_params['rotation_x'],
+ self.data_aug_params['rotation_y'],
+ self.data_aug_params['rotation_z'],
+ self.data_aug_params['scale_range'])
+ self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
+ else:
+ self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
+ self.data_aug_params['rotation_y'],
+ self.data_aug_params['rotation_z'],
+ self.data_aug_params['scale_range'])
+
+ self.data_aug_params['selected_seg_channels'] = [0]
+ self.data_aug_params['patch_size_for_spatialtransform'] = self.patch_size
+
+ self.data_aug_params["p_rot"] = 0.3
+
+ self.data_aug_params["scale_range"] = (0.65, 1.6)
+ self.data_aug_params["p_scale"] = 0.3
+ self.data_aug_params["independent_scale_factor_for_each_axis"] = True
+ self.data_aug_params["p_independent_scale_per_axis"] = 0.3
+
+ self.data_aug_params["do_elastic"] = True
+ self.data_aug_params["p_eldef"] = 0.3
+ self.data_aug_params["eldef_deformation_scale"] = (0, 0.25)
+
+ self.data_aug_params["do_additive_brightness"] = True
+ self.data_aug_params["additive_brightness_mu"] = 0
+ self.data_aug_params["additive_brightness_sigma"] = 0.2
+ self.data_aug_params["additive_brightness_p_per_sample"] = 0.3
+ self.data_aug_params["additive_brightness_p_per_channel"] = 1
+
+ self.data_aug_params['gamma_range'] = (0.5, 1.6)
+
+ self.data_aug_params['num_cached_per_thread'] = 4
+
+ def initialize(self, training=True, force_load_plans=False):
+ if not self.was_initialized:
+ maybe_mkdir_p(self.output_folder)
+
+ if force_load_plans or (self.plans is None):
+ self.load_plans_file()
+
+ self.process_plans(self.plans)
+
+ self.setup_DA_params()
+
+ ################# Here we wrap the loss for deep supervision ############
+ # we need to know the number of outputs of the network
+ net_numpool = len(self.net_num_pool_op_kernel_sizes)
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
+
+            # we don't use the lowest resolution output. Normalize weights so that they sum to 1
+ mask = np.array([True] + [True if i < net_numpool - 1 else False for i in range(1, net_numpool)])
+ weights[~mask] = 0
+ weights = weights / weights.sum()
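+            # e.g. net_numpool = 5: raw weights [1, 0.5, 0.25, 0.125, 0.0625], the coarsest
+            # output is masked out and normalization yields roughly [0.533, 0.267, 0.133, 0.067, 0]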
+ self.ds_loss_weights = weights
+ # now wrap the loss
+ self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
+ ################# END ###################
+
+ self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
+ "_stage%d" % self.stage)
+ if training:
+ self.dl_tr, self.dl_val = self.get_basic_generators()
+ if self.unpack_data:
+ print("unpacking dataset")
+ unpack_dataset(self.folder_with_preprocessed_data)
+ print("done")
+ else:
+ print(
+ "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
+ "will wait all winter for your model to finish!")
+
+ self.tr_gen, self.val_gen = get_insaneDA_augmentation2(
+ self.dl_tr, self.dl_val,
+ self.data_aug_params[
+ 'patch_size_for_spatialtransform'],
+ self.data_aug_params,
+ deep_supervision_scales=self.deep_supervision_scales,
+ pin_memory=self.pin_memory
+ )
+ self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
+ also_print_to_console=False)
+ self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
+ also_print_to_console=False)
+ else:
+ pass
+
+ self.initialize_network()
+ self.initialize_optimizer_and_scheduler()
+
+ assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
+ else:
+ self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
+ self.was_initialized = True
+
+
+class nnUNetTrainerV2_DA3_BN(nnUNetTrainerV2_DA3):
+ def initialize_network(self):
+ if self.threeD:
+ conv_op = nn.Conv3d
+ dropout_op = nn.Dropout3d
+ norm_op = nn.BatchNorm3d
+
+ else:
+ conv_op = nn.Conv2d
+ dropout_op = nn.Dropout2d
+ norm_op = nn.BatchNorm2d
+
+ norm_op_kwargs = {'eps': 1e-5, 'affine': True}
+ dropout_op_kwargs = {'p': 0, 'inplace': True}
+ net_nonlin = nn.LeakyReLU
+ net_nonlin_kwargs = {'negative_slope': 1e-2, 'inplace': True}
+ self.network = Generic_UNet(self.num_input_channels, self.base_num_features, self.num_classes,
+ len(self.net_num_pool_op_kernel_sizes),
+ self.conv_per_stage, 2, conv_op, norm_op, norm_op_kwargs, dropout_op,
+ dropout_op_kwargs,
+ net_nonlin, net_nonlin_kwargs, True, False, lambda x: x, InitWeights_He(1e-2),
+ self.net_num_pool_op_kernel_sizes, self.net_conv_kernel_sizes, False, True, True)
+ if torch.cuda.is_available():
+ self.network.cuda()
+ self.network.inference_apply_nonlin = softmax_helper
diff --git a/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_DA5.py b/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_DA5.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8bbf4d4c01d88875ae309d97e87887fce9c2675
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_DA5.py
@@ -0,0 +1,429 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import List
+
+import numpy as np
+from batchgenerators.dataloading.nondet_multi_threaded_augmenter import NonDetMultiThreadedAugmenter
+from batchgenerators.transforms.abstract_transforms import AbstractTransform, Compose
+from batchgenerators.transforms.channel_selection_transforms import SegChannelSelectionTransform
+from batchgenerators.transforms.color_transforms import BrightnessTransform, ContrastAugmentationTransform, \
+ GammaTransform
+from batchgenerators.transforms.local_transforms import BrightnessGradientAdditiveTransform, LocalGammaTransform
+from batchgenerators.transforms.noise_transforms import BlankRectangleTransform, MedianFilterTransform, \
+ SharpeningTransform
+from batchgenerators.transforms.noise_transforms import GaussianNoiseTransform, GaussianBlurTransform
+from batchgenerators.transforms.resample_transforms import SimulateLowResolutionTransform
+from batchgenerators.transforms.spatial_transforms import Rot90Transform, TransposeAxesTransform, MirrorTransform
+from batchgenerators.transforms.spatial_transforms import SpatialTransform
+from batchgenerators.transforms.utility_transforms import RemoveLabelTransform, RenameTransform, NumpyToTensor, \
+ OneOfTransform
+from batchgenerators.utilities.file_and_folder_operations import maybe_mkdir_p, join
+from torch import nn
+
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.training.data_augmentation.custom_transforms import Convert3DTo2DTransform, Convert2DTo3DTransform, \
+ MaskTransform, ConvertSegmentationToRegionsTransform
+from nnunet.training.data_augmentation.default_data_augmentation import get_patch_size
+from nnunet.training.data_augmentation.downsampling import DownsampleSegForDSTransform2
+from nnunet.training.data_augmentation.pyramid_augmentations import MoveSegAsOneHotToData, \
+ ApplyRandomBinaryOperatorTransform, RemoveRandomConnectedComponentFromOneHotEncodingTransform
+from nnunet.training.dataloading.dataset_loading import unpack_dataset
+from nnunet.training.loss_functions.deep_supervision import MultipleOutputLoss2
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.utilities.set_n_proc_DA import get_allowed_n_proc_DA
+
+
+class nnUNetTrainerV2_DA5(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.do_mirroring = True
+ self.mirror_axes = None
+ proc = get_allowed_n_proc_DA()
+ self.num_proc_DA = proc if proc is not None else 12
+ self.num_cached = 4
+        # regions-based training is disabled by default; subclasses may set these
+        self.regions = None
+        self.regions_class_order = None
+
+ def setup_DA_params(self):
+ self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(
+ np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1]
+
+ self.data_aug_params = dict()
+ self.data_aug_params['scale_range'] = (0.7, 1.43)
+
+        # these defaults need to exist because the cascade trainer overwrites them
+ self.data_aug_params['selected_seg_channels'] = None
+ self.data_aug_params["move_last_seg_chanel_to_data"] = False
+
+ if self.threeD:
+ if self.do_mirroring:
+ self.mirror_axes = (0, 1, 2)
+ self.data_aug_params['do_mirror'] = True # needed for inference
+ self.data_aug_params['mirror_axes'] = (0, 1, 2) # needed for inference
+ else:
+ self.data_aug_params['mirror_axes'] = tuple()
+ self.data_aug_params['do_mirror'] = False
+
+ self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
+ self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
+ self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
+
+ if self.do_dummy_2D_aug:
+ self.print_to_log_file("Using dummy2d data augmentation")
+ self.data_aug_params["dummy_2D"] = True
+ self.data_aug_params["rotation_x"] = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
+ else:
+ if self.do_mirroring:
+ self.mirror_axes = (0, 1)
+ self.data_aug_params['mirror_axes'] = (0, 1) # needed for inference
+ self.data_aug_params['do_mirror'] = True # needed for inference
+ else:
+ self.data_aug_params['mirror_axes'] = tuple()
+ self.data_aug_params['do_mirror'] = False # needed for inference
+
+ self.do_dummy_2D_aug = False
+
+ self.data_aug_params['rotation_x'] = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
+ self.data_aug_params['rotation_y'] = (-0. / 360 * 2. * np.pi, 0. / 360 * 2. * np.pi)
+ self.data_aug_params['rotation_z'] = (-0. / 360 * 2. * np.pi, 0. / 360 * 2. * np.pi)
+
+ self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm
+
+ if self.do_dummy_2D_aug:
+ self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
+ self.data_aug_params['rotation_x'],
+ self.data_aug_params['rotation_y'],
+ self.data_aug_params['rotation_z'],
+ self.data_aug_params['scale_range'])
+ self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
+ else:
+ self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
+ self.data_aug_params['rotation_y'],
+ self.data_aug_params['rotation_z'],
+ self.data_aug_params['scale_range'])
+
+ def get_train_transforms(self) -> List[AbstractTransform]:
+        # used for transpose and rot90
+ matching_axes = np.array([sum([i == j for j in self.patch_size]) for i in self.patch_size])
+ valid_axes = list(np.where(matching_axes == np.max(matching_axes))[0])
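+        # e.g. patch_size (40, 192, 192) -> matching_axes [1, 2, 2] -> valid_axes [1, 2]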
+
+ tr_transforms = []
+
+ if self.data_aug_params['selected_seg_channels'] is not None:
+ tr_transforms.append(SegChannelSelectionTransform(self.data_aug_params['selected_seg_channels']))
+
+ if self.do_dummy_2D_aug:
+ ignore_axes = (0,)
+ tr_transforms.append(Convert3DTo2DTransform())
+ patch_size_spatial = self.patch_size[1:]
+ else:
+ patch_size_spatial = self.patch_size
+ ignore_axes = None
+
+ tr_transforms.append(
+ SpatialTransform(
+ patch_size_spatial,
+ patch_center_dist_from_border=None,
+ do_elastic_deform=False,
+ do_rotation=True,
+ angle_x=self.data_aug_params["rotation_x"],
+ angle_y=self.data_aug_params["rotation_y"],
+ angle_z=self.data_aug_params["rotation_z"],
+ p_rot_per_axis=0.5,
+ do_scale=True,
+ scale=self.data_aug_params['scale_range'],
+ border_mode_data="constant",
+ border_cval_data=0,
+ order_data=3,
+ border_mode_seg="constant",
+ border_cval_seg=-1,
+ order_seg=1,
+ random_crop=False,
+ p_el_per_sample=0.2,
+ p_scale_per_sample=0.2,
+ p_rot_per_sample=0.4,
+ independent_scale_for_each_axis=True,
+ )
+ )
+
+ if self.do_dummy_2D_aug:
+ tr_transforms.append(Convert2DTo3DTransform())
+
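+        # Rot90 and axis transposition are only meaningful along axes of equal size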
+ if np.any(matching_axes > 1):
+ tr_transforms.append(
+ Rot90Transform(
+ (0, 1, 2, 3), axes=valid_axes, data_key='data', label_key='seg', p_per_sample=0.5
+ ),
+ )
+
+ if np.any(matching_axes > 1):
+ tr_transforms.append(
+ TransposeAxesTransform(valid_axes, data_key='data', label_key='seg', p_per_sample=0.5)
+ )
+
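+        # OneOfTransform randomly applies exactly one of its child transforms per sample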
+ tr_transforms.append(OneOfTransform([
+ MedianFilterTransform(
+ (2, 8),
+ same_for_each_channel=False,
+ p_per_sample=0.2,
+ p_per_channel=0.5
+ ),
+ GaussianBlurTransform((0.3, 1.5),
+ different_sigma_per_channel=True,
+ p_per_sample=0.2,
+ p_per_channel=0.5)
+ ]))
+
+ tr_transforms.append(GaussianNoiseTransform(p_per_sample=0.1))
+
+ tr_transforms.append(BrightnessTransform(0,
+ 0.5,
+ per_channel=True,
+ p_per_sample=0.1,
+ p_per_channel=0.5
+ )
+ )
+
+ tr_transforms.append(OneOfTransform(
+ [
+ ContrastAugmentationTransform(
+ contrast_range=(0.5, 2),
+ preserve_range=True,
+ per_channel=True,
+ data_key='data',
+ p_per_sample=0.2,
+ p_per_channel=0.5
+ ),
+ ContrastAugmentationTransform(
+ contrast_range=(0.5, 2),
+ preserve_range=False,
+ per_channel=True,
+ data_key='data',
+ p_per_sample=0.2,
+ p_per_channel=0.5
+ ),
+ ]
+ ))
+
+ tr_transforms.append(
+ SimulateLowResolutionTransform(zoom_range=(0.25, 1),
+ per_channel=True,
+ p_per_channel=0.5,
+ order_downsample=0,
+ order_upsample=3,
+ p_per_sample=0.15,
+ ignore_axes=ignore_axes
+ )
+ )
+
+ tr_transforms.append(
+ GammaTransform((0.7, 1.5), invert_image=True, per_channel=True, retain_stats=True, p_per_sample=0.1))
+        # second gamma pass on the non-inverted image; the verbatim repetition of
+        # invert_image=True above looked like a copy-paste slip
+        tr_transforms.append(
+            GammaTransform((0.7, 1.5), invert_image=False, per_channel=True, retain_stats=True, p_per_sample=0.1))
+
+ if self.do_mirroring:
+ tr_transforms.append(MirrorTransform(self.mirror_axes))
+
+ tr_transforms.append(
+ BlankRectangleTransform([[max(1, p // 10), p // 3] for p in self.patch_size],
+ rectangle_value=np.mean,
+ num_rectangles=(1, 5),
+ force_square=False,
+ p_per_sample=0.4,
+ p_per_channel=0.5
+ )
+ )
+
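+        # in the scale callables below, x appears to be the patch shape and y an axis index:
+        # the lambda samples a log-uniform kernel scale between x[y] // 6 and x[y]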
+ tr_transforms.append(
+ BrightnessGradientAdditiveTransform(
+ lambda x, y: np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y]))),
+ (-0.5, 1.5),
+ max_strength=lambda x, y: np.random.uniform(-5, -1) if np.random.uniform() < 0.5 else np.random.uniform(1, 5),
+ mean_centered=False,
+ same_for_all_channels=False,
+ p_per_sample=0.3,
+ p_per_channel=0.5
+ )
+ )
+
+ tr_transforms.append(
+ LocalGammaTransform(
+ lambda x, y: np.exp(np.random.uniform(np.log(x[y] // 6), np.log(x[y]))),
+ (-0.5, 1.5),
+ lambda: np.random.uniform(0.01, 0.8) if np.random.uniform() < 0.5 else np.random.uniform(1.5, 4),
+ same_for_all_channels=False,
+ p_per_sample=0.3,
+ p_per_channel=0.5
+ )
+ )
+
+ tr_transforms.append(
+ SharpeningTransform(
+ strength=(0.1, 1),
+ same_for_each_channel=False,
+ p_per_sample=0.2,
+ p_per_channel=0.5
+ )
+ )
+
+ if any(self.use_mask_for_norm.values()):
+ tr_transforms.append(MaskTransform(self.use_mask_for_norm, mask_idx_in_seg=0, set_outside_to=0))
+
+ tr_transforms.append(RemoveLabelTransform(-1, 0))
+
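+        # cascade only: append the previous stage's segmentation as extra one-hot data
+        # channels and corrupt it so the network cannot blindly rely on it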
+ if self.data_aug_params["move_last_seg_chanel_to_data"]:
+ all_class_labels = np.arange(1, self.num_classes)
+ tr_transforms.append(MoveSegAsOneHotToData(1, all_class_labels, 'seg', 'data'))
+ if self.data_aug_params["cascade_do_cascade_augmentations"]:
+ tr_transforms.append(
+ ApplyRandomBinaryOperatorTransform(
+ channel_idx=list(range(-len(all_class_labels), 0)),
+ p_per_sample=0.4,
+ key="data",
+ strel_size=(1, 8),
+ p_per_label=1
+ )
+ )
+
+ tr_transforms.append(
+ RemoveRandomConnectedComponentFromOneHotEncodingTransform(
+ channel_idx=list(range(-len(all_class_labels), 0)),
+ key="data",
+ p_per_sample=0.2,
+ fill_with_other_class_p=0.15,
+ dont_do_if_covers_more_than_X_percent=0
+ )
+ )
+
+ tr_transforms.append(RenameTransform('seg', 'target', True))
+
+ if self.regions is not None:
+ tr_transforms.append(ConvertSegmentationToRegionsTransform(self.regions, 'target', 'target'))
+
+ if self.deep_supervision_scales is not None:
+ tr_transforms.append(
+ DownsampleSegForDSTransform2(self.deep_supervision_scales, 0, input_key='target',
+ output_key='target')
+ )
+
+ tr_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ return tr_transforms
+
+ def get_val_transforms(self) -> List[AbstractTransform]:
+ val_transforms = list()
+ val_transforms.append(RemoveLabelTransform(-1, 0))
+
+ if self.data_aug_params['selected_seg_channels'] is not None:
+ val_transforms.append(SegChannelSelectionTransform(self.data_aug_params['selected_seg_channels']))
+
+ if self.data_aug_params["move_last_seg_chanel_to_data"]:
+ all_class_labels = np.arange(1, self.num_classes)
+ val_transforms.append(MoveSegAsOneHotToData(1, all_class_labels, 'seg', 'data'))
+ val_transforms.append(RenameTransform('seg', 'target', True))
+
+ if self.regions is not None:
+ val_transforms.append(ConvertSegmentationToRegionsTransform(self.regions, 'target', 'target'))
+
+ if self.deep_supervision_scales is not None:
+ val_transforms.append(
+ DownsampleSegForDSTransform2(
+ self.deep_supervision_scales, 0, input_key='target',
+ output_key='target')
+ )
+
+ val_transforms.append(NumpyToTensor(['data', 'target'], 'float'))
+ return val_transforms
+
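+    # NonDetMultiThreadedAugmenter gives up deterministic batch ordering for throughput;
+    # validation gets half as many workers as training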
+ def wrap_transforms(self, dataloader_train, dataloader_val, train_transforms, val_transforms):
+ tr_gen = NonDetMultiThreadedAugmenter(dataloader_train,
+ Compose(train_transforms),
+ self.num_proc_DA,
+ self.num_cached,
+ seeds=None,
+ pin_memory=self.pin_memory)
+ val_gen = NonDetMultiThreadedAugmenter(dataloader_val,
+ Compose(val_transforms),
+ self.num_proc_DA // 2,
+ self.num_cached,
+ seeds=None,
+ pin_memory=self.pin_memory)
+ return tr_gen, val_gen
+
+ def initialize(self, training=True, force_load_plans=False):
+ """
+ replace DA
+ :param training:
+ :param force_load_plans:
+ :return:
+ """
+ if not self.was_initialized:
+ maybe_mkdir_p(self.output_folder)
+
+ if force_load_plans or (self.plans is None):
+ self.load_plans_file()
+
+ self.process_plans(self.plans)
+
+ self.setup_DA_params()
+
+ ################# Here we wrap the loss for deep supervision ############
+ # we need to know the number of outputs of the network
+ net_numpool = len(self.net_num_pool_op_kernel_sizes)
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
+
+            # we don't use the lowest resolution output. Normalize weights so that they sum to 1
+ mask = np.array([True] + [True if i < net_numpool - 1 else False for i in range(1, net_numpool)])
+ weights[~mask] = 0
+ weights = weights / weights.sum()
+ self.ds_loss_weights = weights
+ # now wrap the loss
+ self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
+ ################# END ###################
+
+ self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
+ "_stage%d" % self.stage)
+ if training:
+ self.dl_tr, self.dl_val = self.get_basic_generators()
+ if self.unpack_data:
+ print("unpacking dataset")
+ unpack_dataset(self.folder_with_preprocessed_data)
+ print("done")
+ else:
+ print(
+ "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
+ "will wait all winter for your model to finish!")
+
+ tr_transforms = self.get_train_transforms()
+ val_transforms = self.get_val_transforms()
+ self.tr_gen, self.val_gen = self.wrap_transforms(self.dl_tr, self.dl_val, tr_transforms, val_transforms)
+
+ self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
+ also_print_to_console=False)
+ self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
+ also_print_to_console=False)
+ else:
+ pass
+
+ self.initialize_network()
+ self.initialize_optimizer_and_scheduler()
+ self.was_initialized = True
+ assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
+ else:
+ self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
+
diff --git a/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_independentScalePerAxis.py b/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_independentScalePerAxis.py
new file mode 100644
index 0000000000000000000000000000000000000000..aacb7ab09bcf2ffbb310e2bc31b2ff3ac15b2069
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_independentScalePerAxis.py
@@ -0,0 +1,22 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_independentScalePerAxis(nnUNetTrainerV2):
+ def setup_DA_params(self):
+ super().setup_DA_params()
+ self.data_aug_params["independent_scale_factor_for_each_axis"] = True
diff --git a/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_insaneDA.py b/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_insaneDA.py
new file mode 100644
index 0000000000000000000000000000000000000000..634b20b81c8d1faa6c36089ec816a9a89f809322
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_insaneDA.py
@@ -0,0 +1,139 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.training.data_augmentation.data_augmentation_insaneDA import get_insaneDA_augmentation
+from nnunet.training.data_augmentation.default_data_augmentation import default_3D_augmentation_params, \
+ default_2D_augmentation_params, get_patch_size
+from nnunet.training.dataloading.dataset_loading import unpack_dataset
+from nnunet.training.loss_functions.deep_supervision import MultipleOutputLoss2
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from torch import nn
+
+
+class nnUNetTrainerV2_insaneDA(nnUNetTrainerV2):
+ def setup_DA_params(self):
+ self.deep_supervision_scales = [[1, 1, 1]] + list(list(i) for i in 1 / np.cumprod(
+ np.vstack(self.net_num_pool_op_kernel_sizes), axis=0))[:-1]
+
+ if self.threeD:
+ self.data_aug_params = default_3D_augmentation_params
+ self.data_aug_params['rotation_x'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
+ self.data_aug_params['rotation_y'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
+ self.data_aug_params['rotation_z'] = (-30. / 360 * 2. * np.pi, 30. / 360 * 2. * np.pi)
+ if self.do_dummy_2D_aug:
+ self.data_aug_params["dummy_2D"] = True
+ self.print_to_log_file("Using dummy2d data augmentation")
+ self.data_aug_params["elastic_deform_alpha"] = \
+ default_2D_augmentation_params["elastic_deform_alpha"]
+ self.data_aug_params["elastic_deform_sigma"] = \
+ default_2D_augmentation_params["elastic_deform_sigma"]
+ self.data_aug_params["rotation_x"] = default_2D_augmentation_params["rotation_x"]
+ else:
+ self.do_dummy_2D_aug = False
+ if max(self.patch_size) / min(self.patch_size) > 1.5:
+ default_2D_augmentation_params['rotation_x'] = (-180. / 360 * 2. * np.pi, 180. / 360 * 2. * np.pi)
+ self.data_aug_params = default_2D_augmentation_params
+ self.data_aug_params["mask_was_used_for_normalization"] = self.use_mask_for_norm
+
+ if self.do_dummy_2D_aug:
+ self.basic_generator_patch_size = get_patch_size(self.patch_size[1:],
+ self.data_aug_params['rotation_x'],
+ self.data_aug_params['rotation_y'],
+ self.data_aug_params['rotation_z'],
+ self.data_aug_params['scale_range'])
+ self.basic_generator_patch_size = np.array([self.patch_size[0]] + list(self.basic_generator_patch_size))
+ else:
+ self.basic_generator_patch_size = get_patch_size(self.patch_size, self.data_aug_params['rotation_x'],
+ self.data_aug_params['rotation_y'],
+ self.data_aug_params['rotation_z'],
+ self.data_aug_params['scale_range'])
+
+ self.data_aug_params["scale_range"] = (0.65, 1.6)
+
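+        # re-enable elastic deformation (nnUNetTrainerV2 turns it off by default)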
+ self.data_aug_params["do_elastic"] = True
+ self.data_aug_params["elastic_deform_alpha"] = (0., 1300.)
+ self.data_aug_params["elastic_deform_sigma"] = (9., 15.)
+ self.data_aug_params["p_eldef"] = 0.2
+
+ self.data_aug_params['selected_seg_channels'] = [0]
+
+ self.data_aug_params['gamma_range'] = (0.6, 2)
+
+ self.data_aug_params['patch_size_for_spatialtransform'] = self.patch_size
+
+ def initialize(self, training=True, force_load_plans=False):
+ if not self.was_initialized:
+ maybe_mkdir_p(self.output_folder)
+
+ if force_load_plans or (self.plans is None):
+ self.load_plans_file()
+
+ self.process_plans(self.plans)
+
+ self.setup_DA_params()
+
+ ################# Here we wrap the loss for deep supervision ############
+ # we need to know the number of outputs of the network
+ net_numpool = len(self.net_num_pool_op_kernel_sizes)
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
+
+            # we don't use the lowest resolution output. Normalize weights so that they sum to 1
+ mask = np.array([True if i < net_numpool - 1 else False for i in range(net_numpool)])
+ weights[~mask] = 0
+ weights = weights / weights.sum()
+
+ # now wrap the loss
+ self.loss = MultipleOutputLoss2(self.loss, weights)
+ ################# END ###################
+
+ self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
+ "_stage%d" % self.stage)
+ if training:
+ self.dl_tr, self.dl_val = self.get_basic_generators()
+ if self.unpack_data:
+ print("unpacking dataset")
+ unpack_dataset(self.folder_with_preprocessed_data)
+ print("done")
+ else:
+ print(
+ "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
+ "will wait all winter for your model to finish!")
+
+ self.tr_gen, self.val_gen = get_insaneDA_augmentation(self.dl_tr, self.dl_val,
+ self.data_aug_params[
+ 'patch_size_for_spatialtransform'],
+ self.data_aug_params,
+ deep_supervision_scales=self.deep_supervision_scales,
+ pin_memory=self.pin_memory)
+ self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
+ also_print_to_console=False)
+ self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
+ also_print_to_console=False)
+ else:
+ pass
+
+ self.initialize_network()
+ self.initialize_optimizer_and_scheduler()
+
+ assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
+ else:
+ self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
+ self.was_initialized = True
diff --git a/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_noDA.py b/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_noDA.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a1d84f214fff11b95101e9fc6fd5e2f8a9e02f2
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_noDA.py
@@ -0,0 +1,143 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Tuple
+
+import numpy as np
+from batchgenerators.utilities.file_and_folder_operations import join, maybe_mkdir_p
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.training.data_augmentation.data_augmentation_noDA import get_no_augmentation
+from nnunet.training.dataloading.dataset_loading import unpack_dataset, DataLoader3D, DataLoader2D
+from nnunet.training.loss_functions.deep_supervision import MultipleOutputLoss2
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from torch import nn
+
+
+class nnUNetTrainerV2_noDataAugmentation(nnUNetTrainerV2):
+ def setup_DA_params(self):
+ super().setup_DA_params()
+ # important because we need to know in validation and inference that we did not mirror in training
+ self.data_aug_params["do_mirror"] = False
+ self.data_aug_params["mirror_axes"] = tuple()
+
+ def get_basic_generators(self):
+ self.load_dataset()
+ self.do_split()
+
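+        # without spatial augmentation no enlarged basic_generator_patch_size is needed,
+        # so the loaders sample patch_size directly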
+ if self.threeD:
+ dl_tr = DataLoader3D(self.dataset_tr, self.patch_size, self.patch_size, self.batch_size,
+                                 False, oversample_foreground_percent=self.oversample_foreground_percent,
+                                 pad_mode="constant", pad_sides=self.pad_all_sides)
+ dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, False,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ pad_mode="constant", pad_sides=self.pad_all_sides)
+ else:
+ dl_tr = DataLoader2D(self.dataset_tr, self.patch_size, self.patch_size, self.batch_size,
+ transpose=self.plans.get('transpose_forward'),
+                                 oversample_foreground_percent=self.oversample_foreground_percent,
+                                 pad_mode="constant", pad_sides=self.pad_all_sides)
+ dl_val = DataLoader2D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size,
+ transpose=self.plans.get('transpose_forward'),
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ pad_mode="constant", pad_sides=self.pad_all_sides)
+ return dl_tr, dl_val
+
+ def initialize(self, training=True, force_load_plans=False):
+ if not self.was_initialized:
+ maybe_mkdir_p(self.output_folder)
+
+ if force_load_plans or (self.plans is None):
+ self.load_plans_file()
+
+ self.process_plans(self.plans)
+
+ self.setup_DA_params()
+
+ ################# Here we wrap the loss for deep supervision ############
+ # we need to know the number of outputs of the network
+ net_numpool = len(self.net_num_pool_op_kernel_sizes)
+
+ # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
+ # this gives higher resolution outputs more weight in the loss
+ weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
+
+            # we don't use the lowest resolution output. Normalize weights so that they sum to 1
+ mask = np.array([True if i < net_numpool - 1 else False for i in range(net_numpool)])
+ weights[~mask] = 0
+ weights = weights / weights.sum()
+
+ # now wrap the loss
+ self.loss = MultipleOutputLoss2(self.loss, weights)
+ ################# END ###################
+
+ self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
+ "_stage%d" % self.stage)
+ if training:
+ self.dl_tr, self.dl_val = self.get_basic_generators()
+ if self.unpack_data:
+ print("unpacking dataset")
+ unpack_dataset(self.folder_with_preprocessed_data)
+ print("done")
+ else:
+ print(
+ "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
+ "will wait all winter for your model to finish!")
+
+ self.tr_gen, self.val_gen = get_no_augmentation(self.dl_tr, self.dl_val,
+ params=self.data_aug_params,
+ deep_supervision_scales=self.deep_supervision_scales,
+ pin_memory=self.pin_memory)
+
+ self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
+ also_print_to_console=False)
+ self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
+ also_print_to_console=False)
+ else:
+ pass
+
+ self.initialize_network()
+ self.initialize_optimizer_and_scheduler()
+
+ assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
+ else:
+ self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
+ self.was_initialized = True
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
+ step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
+ """
+ We need to wrap this because we need to enforce self.network.do_ds = False for prediction
+        """
+ ds = self.network.do_ds
+ if do_mirroring:
+ print("WARNING! do_mirroring was True but we cannot do that because we trained without mirroring. "
+ "do_mirroring was set to False")
+ do_mirroring = False
+ self.network.do_ds = False
+ ret = super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window, step_size=step_size,
+ save_softmax=save_softmax, use_gaussian=use_gaussian,
+ overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug,
+ all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs,
+ run_postprocessing_on_folds=run_postprocessing_on_folds)
+ self.network.do_ds = ds
+ return ret
+
+
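+# aliases analogous to nnUNetTrainerV2_copy*: identically behaving trainers under
+# different names so that independently trained models can coexist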
+nnUNetTrainerV2_noDataAugmentation_copy1 = nnUNetTrainerV2_noDataAugmentation
+nnUNetTrainerV2_noDataAugmentation_copy2 = nnUNetTrainerV2_noDataAugmentation
+nnUNetTrainerV2_noDataAugmentation_copy3 = nnUNetTrainerV2_noDataAugmentation
+nnUNetTrainerV2_noDataAugmentation_copy4 = nnUNetTrainerV2_noDataAugmentation
diff --git a/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_noMirroring.py b/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_noMirroring.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0baa522d8732edf196e39c86528db9ba96c65f6
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/data_augmentation/nnUNetTrainerV2_noMirroring.py
@@ -0,0 +1,43 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_noMirroring(nnUNetTrainerV2):
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
+ step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
+ """
+ We need to wrap this because we need to enforce self.network.do_ds = False for prediction
+ """
+ ds = self.network.do_ds
+ if do_mirroring:
+ print("WARNING! do_mirroring was True but we cannot do that because we trained without mirroring. "
+ "do_mirroring was set to False")
+ do_mirroring = False
+ self.network.do_ds = False
+ ret = super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window, step_size=step_size,
+ save_softmax=save_softmax, use_gaussian=use_gaussian,
+ overwrite=overwrite, validation_folder_name=validation_folder_name, debug=debug,
+ all_in_gpu=all_in_gpu, segmentation_export_kwargs=segmentation_export_kwargs,
+ run_postprocessing_on_folds=run_postprocessing_on_folds)
+ self.network.do_ds = ds
+ return ret
+
+ def setup_DA_params(self):
+ super().setup_DA_params()
+ self.data_aug_params["do_mirror"] = False
diff --git a/nnunet/training/network_training/nnUNet_variants/loss_function/__init__.py b/nnunet/training/network_training/nnUNet_variants/loss_function/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_ForceBD.py b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_ForceBD.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc9e5d0901ae41b3c6a42a9e8d7d0850ce4a8c0d
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_ForceBD.py
@@ -0,0 +1,24 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_ForceBD(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
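+        # override whatever was passed in and force batch dice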
+ batch_dice = True
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
diff --git a/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_ForceSD.py b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_ForceSD.py
new file mode 100644
index 0000000000000000000000000000000000000000..4fac9015d58ff04152866ac0bc9c55c88eab032c
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_ForceSD.py
@@ -0,0 +1,24 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_ForceSD(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
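+        # override whatever was passed in and force sample dice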
+ batch_dice = False
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
diff --git a/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_CE.py b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_CE.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6cf68adf0f28c4a7d9f9b123871906f9579ee76
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_CE.py
@@ -0,0 +1,23 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from nnunet.training.loss_functions.crossentropy import RobustCrossEntropyLoss
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_Loss_CE(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.loss = RobustCrossEntropyLoss()
diff --git a/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_CEGDL.py b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_CEGDL.py
new file mode 100644
index 0000000000000000000000000000000000000000..54fb4715b469e591b6a5b909562b295df2b3de06
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_CEGDL.py
@@ -0,0 +1,25 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.training.loss_functions.dice_loss import GDL_and_CE_loss
+
+
+class nnUNetTrainerV2_Loss_CEGDL(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.loss = GDL_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {})
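+
+# Editorial sketch (not part of nnU-Net): the GDL component weights each class by the
+# inverse of its squared reference volume (Sudre et al. 2017), so small structures are
+# not drowned out by large ones. Shapes and values below are demo assumptions.
+if __name__ == '__main__':
+ import torch
+ import torch.nn.functional as F
+ pred = torch.softmax(torch.rand(2, 3, 8, 8), dim=1) # B, C, H, W
+ onehot = F.one_hot(torch.randint(0, 3, (2, 8, 8)), 3).permute(0, 3, 1, 2).float()
+ axes = (0, 2, 3)
+ w = 1.0 / (onehot.sum(axes) ** 2 + 1e-10) # per-class weights
+ gdl = 1 - 2 * (w * (pred * onehot).sum(axes)).sum() / (w * (pred + onehot).sum(axes)).sum()
+ print(gdl.item())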
diff --git a/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_Dice.py b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_Dice.py
new file mode 100644
index 0000000000000000000000000000000000000000..683e2813814fce7102427b853da080d517f62258
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_Dice.py
@@ -0,0 +1,35 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.training.loss_functions.dice_loss import SoftDiceLoss
+from nnunet.utilities.nd_softmax import softmax_helper
+
+
+class nnUNetTrainerV2_Loss_Dice(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.loss = SoftDiceLoss(apply_nonlin=softmax_helper, batch_dice=self.batch_dice, smooth=1e-5, do_bg=False)
+
+
+class nnUNetTrainerV2_Loss_DicewithBG(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.loss = SoftDiceLoss(apply_nonlin=softmax_helper, batch_dice=self.batch_dice, smooth=1e-5, do_bg=True)
+
diff --git a/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_DiceCE_noSmooth.py b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_DiceCE_noSmooth.py
new file mode 100644
index 0000000000000000000000000000000000000000..204ee5912a041b876351d739d8b2f1300a2a086d
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_DiceCE_noSmooth.py
@@ -0,0 +1,27 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.training.loss_functions.dice_loss import SoftDiceLoss, DC_and_CE_loss
+
+
+class nnUNetTrainerV2_Loss_DiceCE_noSmooth(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 0, 'do_bg': False}, {})
+
+
diff --git a/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_DiceTopK10.py b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_DiceTopK10.py
new file mode 100644
index 0000000000000000000000000000000000000000..751cda8fba3dde2e64a4363a32d59d4b7e4074e7
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_DiceTopK10.py
@@ -0,0 +1,26 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.training.loss_functions.dice_loss import DC_and_topk_loss
+
+
+class nnUNetTrainerV2_Loss_DiceTopK10(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.loss = DC_and_topk_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False},
+ {'k': 10})
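+
+# Editorial sketch (not part of nnU-Net): the TopK part averages the cross-entropy over
+# only the k=10 percent hardest pixels, focusing the gradient on difficult regions; the
+# Dice part is unchanged. Shapes below are demo assumptions.
+if __name__ == '__main__':
+ import torch
+ import torch.nn.functional as F
+ logits = torch.randn(2, 3, 16, 16)
+ target = torch.randint(0, 3, (2, 16, 16))
+ ce = F.cross_entropy(logits, target, reduction='none') # per-pixel CE
+ k = max(1, int(ce.numel() * 10 / 100)) # keep the hardest 10%
+ print(ce.view(-1).topk(k)[0].mean().item())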
diff --git a/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_Dice_lr1en3.py b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_Dice_lr1en3.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce37df3163793e687f3a0a78922009e1e15e6f55
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_Dice_lr1en3.py
@@ -0,0 +1,34 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNet_variants.loss_function.nnUNetTrainerV2_Loss_Dice import \
+ nnUNetTrainerV2_Loss_Dice, nnUNetTrainerV2_Loss_DicewithBG
+
+
+class nnUNetTrainerV2_Loss_Dice_LR1en3(nnUNetTrainerV2_Loss_Dice):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.initial_lr = 1e-3
+
+
+class nnUNetTrainerV2_Loss_DicewithBG_LR1en3(nnUNetTrainerV2_Loss_DicewithBG):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.initial_lr = 1e-3
+
diff --git a/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_Dice_squared.py b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_Dice_squared.py
new file mode 100644
index 0000000000000000000000000000000000000000..58bf7708e5fe66eadfeb47177aa92f6dd42389d3
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_Dice_squared.py
@@ -0,0 +1,27 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.training.loss_functions.dice_loss import SoftDiceLossSquared
+from nnunet.utilities.nd_softmax import softmax_helper
+
+
+class nnUNetTrainerV2_Loss_Dice_squared(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.initial_lr = 1e-3
+ self.loss = SoftDiceLossSquared(apply_nonlin=softmax_helper, batch_dice=self.batch_dice, smooth=1e-5, do_bg=False)
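+
+
+# Editorial sketch (not part of nnU-Net): the "squared" variant uses p^2 + g^2 in the
+# denominator (as in the V-Net paper, Milletari et al. 2016) instead of p + g. Demo
+# values only.
+if __name__ == '__main__':
+ import torch
+ p, g, s = torch.rand(1000), (torch.rand(1000) > 0.5).float(), 1e-5
+ plain = (2 * (p * g).sum() + s) / (p.sum() + g.sum() + s)
+ squared = (2 * (p * g).sum() + s) / ((p ** 2).sum() + (g ** 2).sum() + s)
+ print(plain.item(), squared.item())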
diff --git a/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_MCC.py b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_MCC.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a3fe3f9d695f4f3f6b27d92d2007b4338704adc
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_MCC.py
@@ -0,0 +1,37 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.training.loss_functions.dice_loss import MCCLoss
+from nnunet.utilities.nd_softmax import softmax_helper
+
+
+class nnUNetTrainerV2_Loss_MCC(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.initial_lr = 1e-3
+ self.loss = MCCLoss(apply_nonlin=softmax_helper, batch_mcc=self.batch_dice, do_bg=True, smooth=0.0)
+
+
+class nnUNetTrainerV2_Loss_MCCnoBG(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.initial_lr = 1e-3
+ self.loss = MCCLoss(apply_nonlin=softmax_helper, batch_mcc=self.batch_dice, do_bg=False, smooth=0.0)
+
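+# Editorial sketch (not part of nnU-Net): a soft Matthews correlation coefficient built
+# from soft TP/FP/FN/TN counts; MCCLoss presumably minimises 1 - MCC. Binary
+# single-class demo; shapes and the epsilon are assumptions.
+if __name__ == '__main__':
+ import torch
+ p = torch.rand(1000) # predicted foreground probabilities
+ g = (torch.rand(1000) > 0.5).float() # binary reference
+ tp, fp = (p * g).sum(), (p * (1 - g)).sum()
+ fn, tn = ((1 - p) * g).sum(), ((1 - p) * (1 - g)).sum()
+ mcc = (tp * tn - fp * fn) / torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn) + 1e-8)
+ print((1 - mcc).item()) # the corresponding loss value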
diff --git a/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_TopK10.py b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_TopK10.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2b0d56bcc69e913b49b0aa70db44d076ebc1ea7
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_Loss_TopK10.py
@@ -0,0 +1,32 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.training.loss_functions.TopK_loss import TopKLoss
+
+
+class nnUNetTrainerV2_Loss_TopK10(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.loss = TopKLoss(k=10)
+
+
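+# The aliases below are identical copies; presumably they exist so the same
+# configuration can be trained several times under different trainer names (e.g. to
+# gauge run-to-run variance).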
+nnUNetTrainerV2_Loss_TopK10_copy1 = nnUNetTrainerV2_Loss_TopK10
+nnUNetTrainerV2_Loss_TopK10_copy2 = nnUNetTrainerV2_Loss_TopK10
+nnUNetTrainerV2_Loss_TopK10_copy3 = nnUNetTrainerV2_Loss_TopK10
+nnUNetTrainerV2_Loss_TopK10_copy4 = nnUNetTrainerV2_Loss_TopK10
+
diff --git a/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_focalLoss.py b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_focalLoss.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfb0ad8c0ee81d7c000147ceb33bc7ef76019564
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_focalLoss.py
@@ -0,0 +1,28 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from nnunet.training.loss_functions.focal_loss import FocalLossV2
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from torch import nn
+
+
+class nnUNetTrainerV2_SegLoss_Focal(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage,
+ unpack_data, deterministic, fp16)
+ print("Setting up self.loss = Focal_loss({'alpha':0.75, 'gamma':2, 'smooth':1e-5})")
+ self.loss = FocalLossV2(apply_nonlin=nn.Softmax(dim=1), **{'alpha':0.5, 'gamma':2, 'smooth':1e-5})
+
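+# Editorial sketch (not part of nnU-Net): the focal term down-weights easy pixels by
+# (1 - p_t)^gamma (Lin et al. 2017). Binary demo with the alpha/gamma used above;
+# FocalLossV2's exact internals are not reproduced here.
+if __name__ == '__main__':
+ import torch
+ p = torch.rand(1000) # predicted foreground probabilities
+ g = (torch.rand(1000) > 0.5).float() # binary reference
+ alpha, gamma, eps = 0.5, 2.0, 1e-5
+ pt = p * g + (1 - p) * (1 - g) # probability of the true class
+ w = alpha * g + (1 - alpha) * (1 - g) # class balancing term
+ print((-w * (1 - pt) ** gamma * torch.log(pt + eps)).mean().item())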
+
diff --git a/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_graduallyTransitionFromCEToDice.py b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_graduallyTransitionFromCEToDice.py
new file mode 100644
index 0000000000000000000000000000000000000000..77159b6bea28d166eb73d2431f1c34c5d0c50189
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/loss_function/nnUNetTrainerV2_graduallyTransitionFromCEToDice.py
@@ -0,0 +1,58 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.loss_functions.deep_supervision import MultipleOutputLoss2
+from nnunet.training.loss_functions.dice_loss import DC_and_CE_loss
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_graduallyTransitionFromCEToDice(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {}, weight_ce=2, weight_dice=0)
+
+ def update_loss(self):
+ # we train the first 500 epochs with CE, then transition to Dice between 500 and 750. The last 250 epochs will be Dice only
+
+ if self.epoch <= 500:
+ weight_ce = 2
+ weight_dice = 0
+ elif 500 < self.epoch <= 750:
+ weight_ce = 2 - 2 / 250 * (self.epoch - 500)
+ weight_dice = 0 + 2 / 250 * (self.epoch - 500)
+ elif 750 < self.epoch <= self.max_num_epochs:
+ weight_ce = 0
+ weight_dice = 2
+ else:
+ raise RuntimeError("Invalid epoch: %d" % self.epoch)
+
+ self.print_to_log_file("weight ce", weight_ce, "weight dice", weight_dice)
+
+ self.loss = DC_and_CE_loss({'batch_dice': self.batch_dice, 'smooth': 1e-5, 'do_bg': False}, {}, weight_ce=weight_ce,
+ weight_dice=weight_dice)
+
+ self.loss = MultipleOutputLoss2(self.loss, self.ds_loss_weights)
+
+ def on_epoch_end(self):
+ ret = super().on_epoch_end()
+ self.update_loss()
+ return ret
+
+ def load_checkpoint_ram(self, checkpoint, train=True):
+ ret = super().load_checkpoint_ram(checkpoint, train)
+ self.update_loss()
+ return ret
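+
+
+# Editorial sketch (not part of nnU-Net): the CE/Dice weights produced by the schedule
+# above at a few sample epochs, assuming max_num_epochs=1000.
+if __name__ == '__main__':
+ for ep in (0, 500, 625, 750, 1000):
+ if ep <= 500:
+ w_ce, w_dice = 2.0, 0.0
+ elif ep <= 750:
+ w_ce, w_dice = 2 - 2 / 250 * (ep - 500), 2 / 250 * (ep - 500)
+ else:
+ w_ce, w_dice = 0.0, 2.0
+ print(ep, w_ce, w_dice) # (0, 2, 0) (500, 2, 0) (625, 1, 1) (750, 0, 2) (1000, 0, 2)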
diff --git a/nnunet/training/network_training/nnUNet_variants/miscellaneous/__init__.py b/nnunet/training/network_training/nnUNet_variants/miscellaneous/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/training/network_training/nnUNet_variants/miscellaneous/nnUNetTrainerV2_fullEvals.py b/nnunet/training/network_training/nnUNet_variants/miscellaneous/nnUNetTrainerV2_fullEvals.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c68bfab22f7047b04a6b8af8fbdd63d15fa46cc
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/miscellaneous/nnUNetTrainerV2_fullEvals.py
@@ -0,0 +1,195 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from multiprocessing.pool import Pool
+from time import time
+
+import numpy as np
+import torch
+from nnunet.configuration import default_num_threads
+from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.evaluation.region_based_evaluation import evaluate_regions, get_brats_regions
+
+
+class nnUNetTrainerV2_fullEvals(nnUNetTrainerV2):
+ """
+ this trainer only works for brats and nothing else
+ """
+
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.validate_every = 1
+ self.evaluation_regions = get_brats_regions()
+ self.num_val_batches_per_epoch = 0 # patch-based online validation is not needed here; validate() evaluates full images instead
+
+ def finish_online_evaluation(self):
+ pass
+
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
+ step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ force_separate_z: bool = None, interpolation_order: int = 3, interpolation_order_z=0):
+ """
+ disable nnunet postprocessing. this would just waste computation time and does not benefit brats
+
+ !!!We run this with use_sliding_window=False per default (see on_epoch_end). This triggers fully convolutional
+ inference. THIS ONLY MAKES SENSE WHEN TRAINING ON FULL IMAGES! Make sure use_sliding_window=True when running
+ with default patch size (128x128x128)!!!
+
+ per default this does not use test time data augmentation (mirroring). The reference implementation, however,
+ does. I disabled it here because this eats up a lot of computation time
+
+ """
+ validation_start = time()
+
+ current_mode = self.network.training
+ self.network.eval()
+
+ assert self.was_initialized, "must initialize, ideally with checkpoint (or train first)"
+ if self.dataset_val is None:
+ self.load_dataset()
+ self.do_split()
+
+ # predictions as they come from the network go here
+ output_folder = join(self.output_folder, validation_folder_name)
+ maybe_mkdir_p(output_folder)
+
+ # this is for debug purposes
+ my_input_args = {'do_mirroring': do_mirroring,
+ 'use_sliding_window': use_sliding_window,
+ 'step_size': step_size,
+ 'save_softmax': save_softmax,
+ 'use_gaussian': use_gaussian,
+ 'overwrite': overwrite,
+ 'validation_folder_name': validation_folder_name,
+ 'debug': debug,
+ 'all_in_gpu': all_in_gpu,
+ 'force_separate_z': force_separate_z,
+ 'interpolation_order': interpolation_order,
+ 'interpolation_order_z': interpolation_order_z,
+ }
+ save_json(my_input_args, join(output_folder, "validation_args.json"))
+
+ if do_mirroring:
+ if not self.data_aug_params['do_mirror']:
+ raise RuntimeError("We did not train with mirroring so you cannot do inference with mirroring enabled")
+ mirror_axes = self.data_aug_params['mirror_axes']
+ else:
+ mirror_axes = ()
+
+ export_pool = Pool(default_num_threads)
+ results = []
+
+ for k in self.dataset_val.keys():
+ properties = load_pickle(self.dataset[k]['properties_file'])
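+ # case files follow nnU-Net's "<case>_0000.nii.gz" naming; strip the 12-character modality+extension suffix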
+ fname = properties['list_of_data_files'][0].split("/")[-1][:-12]
+ if overwrite or (not isfile(join(output_folder, fname + ".nii.gz"))) or \
+ (save_softmax and not isfile(join(output_folder, fname + ".npz"))):
+ data = np.load(self.dataset[k]['data_file'])['data']
+
+ #print(k, data.shape)
+
+ softmax_pred = self.predict_preprocessed_data_return_seg_and_softmax(data[:-1],
+ do_mirroring=do_mirroring,
+ mirror_axes=mirror_axes,
+ use_sliding_window=use_sliding_window,
+ step_size=step_size,
+ use_gaussian=use_gaussian,
+ all_in_gpu=all_in_gpu,
+ verbose=False,
+ mixed_precision=self.fp16)[1]
+
+ # the transpose below does not do anything for brats, hence it is commented out:
+ # softmax_pred = softmax_pred.transpose([0] + [i + 1 for i in self.transpose_backward])
+
+ if save_softmax:
+ softmax_fname = join(output_folder, fname + ".npz")
+ else:
+ softmax_fname = None
+
+ results.append(export_pool.starmap_async(save_segmentation_nifti_from_softmax,
+ ((softmax_pred, join(output_folder, fname + ".nii.gz"),
+ properties, interpolation_order, None, None, None,
+ softmax_fname, None, force_separate_z,
+ interpolation_order_z, False),
+ )
+ )
+ )
+
+ _ = [i.get() for i in results]
+ self.print_to_log_file("finished prediction")
+
+ # evaluate raw predictions
+ self.print_to_log_file("evaluation of raw predictions")
+
+ # this writes a csv file into output_folder
+ evaluate_regions(output_folder, self.gt_niftis_folder, self.evaluation_regions)
+ csv_file = np.loadtxt(join(output_folder, 'summary.csv'), skiprows=1, dtype=str, delimiter=',')[:, 1:]
+
+ # these are the values that are computed with np.nanmean aggregation
+ whole, core, enhancing = csv_file[-4, :].astype(float)
+
+ # do some cleanup
+ if torch.cuda.is_available():
+ torch.cuda.empty_cache()
+
+ self.network.train(current_mode)
+ validation_end = time()
+ self.print_to_log_file('Running the validation took %f seconds' % (validation_end - validation_start))
+ self.print_to_log_file('(the time needed for validation is included in the total epoch time!)')
+
+ return whole, core, enhancing
+
+ def on_epoch_end(self):
+ return_value = True
+
+ # on_epoch_end is called before the epoch counter is incremented, so we add 1 here to get the correct epoch number
+ if (self.epoch + 1) % self.validate_every == 0:
+ whole, core, enhancing = self.validate(do_mirroring=False, use_sliding_window=True,
+ step_size=0.5,
+ save_softmax=False,
+ use_gaussian=True, overwrite=True,
+ validation_folder_name='validation_after_ep_%04d' % self.epoch,
+ debug=False, all_in_gpu=True)
+
+ here = np.mean((whole, core, enhancing))
+
+ self.print_to_log_file("After epoch %d: whole %0.4f core %0.4f enhancing: %0.4f" %
+ (self.epoch, whole, core, enhancing))
+ self.print_to_log_file("Mean: %0.4f" % here)
+
+ # now we need to figure out if we are done
+ fully_trained_nnunet = (0.911, 0.8739, 0.7848)
+ mean_dice = np.mean(fully_trained_nnunet)
+ target = 0.97 * mean_dice
+
+ self.all_val_eval_metrics.append(here)
+ self.print_to_log_file("Target mean: %0.4f" % target)
+
+ if here >= target:
+ self.print_to_log_file("I am done!")
+ self.save_checkpoint(join(self.output_folder, "model_final_checkpoint.model"))
+ return_value = False # this triggers early stopping
+
+ ret_old = super().on_epoch_end()
+ # if we do not achieve the target accuracy in 1000 epochs then we need to stop the training. This is not built
+ # to run longer than 1000 epochs
+ if not ret_old:
+ return_value = ret_old
+
+ return return_value
diff --git a/nnunet/training/network_training/nnUNet_variants/nnUNetTrainerCE.py b/nnunet/training/network_training/nnUNet_variants/nnUNetTrainerCE.py
new file mode 100644
index 0000000000000000000000000000000000000000..689dcbf552a647039a315d9121b5de3253585563
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/nnUNetTrainerCE.py
@@ -0,0 +1,23 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from nnunet.training.loss_functions.crossentropy import RobustCrossEntropyLoss
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+
+
+class nnUNetTrainerCE(nnUNetTrainer):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super(nnUNetTrainerCE, self).__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage,
+ unpack_data, deterministic, fp16)
+ self.loss = RobustCrossEntropyLoss()
diff --git a/nnunet/training/network_training/nnUNet_variants/nnUNetTrainerNoDA.py b/nnunet/training/network_training/nnUNet_variants/nnUNetTrainerNoDA.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8a83cf0f9add03510d581204cf3d192746f3154
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/nnUNetTrainerNoDA.py
@@ -0,0 +1,88 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import matplotlib
+from batchgenerators.utilities.file_and_folder_operations import maybe_mkdir_p, join
+from nnunet.network_architecture.neural_network import SegmentationNetwork
+from nnunet.training.data_augmentation.data_augmentation_noDA import get_no_augmentation
+from nnunet.training.dataloading.dataset_loading import unpack_dataset, DataLoader3D, DataLoader2D
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+from torch import nn
+
+matplotlib.use("agg")
+
+
+class nnUNetTrainerNoDA(nnUNetTrainer):
+ def get_basic_generators(self):
+ self.load_dataset()
+ self.do_split()
+
+ if self.threeD:
+ dl_tr = DataLoader3D(self.dataset_tr, self.patch_size, self.patch_size, self.batch_size,
+ False, oversample_foreground_percent=self.oversample_foreground_percent,
+ pad_mode="constant", pad_sides=self.pad_all_sides)
+ dl_val = DataLoader3D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size, False,
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ pad_mode="constant", pad_sides=self.pad_all_sides)
+ else:
+ dl_tr = DataLoader2D(self.dataset_tr, self.patch_size, self.patch_size, self.batch_size,
+ transpose=self.plans.get('transpose_forward'),
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ pad_mode="constant", pad_sides=self.pad_all_sides)
+ dl_val = DataLoader2D(self.dataset_val, self.patch_size, self.patch_size, self.batch_size,
+ transpose=self.plans.get('transpose_forward'),
+ oversample_foreground_percent=self.oversample_foreground_percent,
+ pad_mode="constant", pad_sides=self.pad_all_sides)
+ return dl_tr, dl_val
+
+ def initialize(self, training=True, force_load_plans=False):
+ """
+ For prediction of test cases just set training=False, this will prevent loading of training data and
+ training batchgenerator initialization
+ :param training:
+ :return:
+ """
+
+ maybe_mkdir_p(self.output_folder)
+
+ if force_load_plans or (self.plans is None):
+ self.load_plans_file()
+
+ self.process_plans(self.plans)
+
+ self.setup_DA_params()
+
+ self.folder_with_preprocessed_data = join(self.dataset_directory, self.plans['data_identifier'] +
+ "_stage%d" % self.stage)
+ if training:
+ self.dl_tr, self.dl_val = self.get_basic_generators()
+ if self.unpack_data:
+ print("unpacking dataset")
+ unpack_dataset(self.folder_with_preprocessed_data)
+ print("done")
+ else:
+ print("INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
+ "will wait all winter for your model to finish!")
+ self.tr_gen, self.val_gen = get_no_augmentation(self.dl_tr, self.dl_val, params=self.data_aug_params)
+ self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
+ also_print_to_console=False)
+ self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
+ also_print_to_console=False)
+ self.initialize_network()
+ assert isinstance(self.network, (SegmentationNetwork, nn.DataParallel))
+ self.was_initialized = True
+ self.data_aug_params['mirror_axes'] = ()
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/__init__.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Adam.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Adam.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6a01f7c902bccb1375cb66c2096353ed35a327e
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Adam.py
@@ -0,0 +1,30 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_Adam(nnUNetTrainerV2):
+
+ def initialize_optimizer_and_scheduler(self):
+ self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay, amsgrad=True)
+ self.lr_scheduler = None
+
+
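+# Identical aliases; presumably for launching the same Adam configuration several
+# times under different trainer names (e.g. to gauge run-to-run variance).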
+nnUNetTrainerV2_Adam_copy1 = nnUNetTrainerV2_Adam
+nnUNetTrainerV2_Adam_copy2 = nnUNetTrainerV2_Adam
+nnUNetTrainerV2_Adam_copy3 = nnUNetTrainerV2_Adam
+nnUNetTrainerV2_Adam_copy4 = nnUNetTrainerV2_Adam
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Adam_ReduceOnPlateau.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Adam_ReduceOnPlateau.py
new file mode 100644
index 0000000000000000000000000000000000000000..b467ad997af23eace8f8de25fc52dfcbe29b702e
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Adam_ReduceOnPlateau.py
@@ -0,0 +1,55 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from torch.optim import lr_scheduler
+
+
+class nnUNetTrainerV2_Adam_ReduceOnPlateau(nnUNetTrainerV2):
+ """
+ Same schedule as nnUNetTrainer
+ """
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.initial_lr = 3e-4
+
+ def initialize_optimizer_and_scheduler(self):
+ assert self.network is not None, "self.initialize_network must be called first"
+ self.optimizer = torch.optim.Adam(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ amsgrad=True)
+ self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2,
+ patience=self.lr_scheduler_patience,
+ verbose=True, threshold=self.lr_scheduler_eps,
+ threshold_mode="abs")
+
+ def maybe_update_lr(self, epoch=None):
+ # maybe update learning rate
+ if self.lr_scheduler is not None:
+ assert isinstance(self.lr_scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))
+
+ if isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
+ # the lr scheduler is stepped with the moving average train loss, which should be more robust
+ if self.epoch > 0 and self.train_loss_MA is not None: # otherwise self.train_loss_MA is None
+ self.lr_scheduler.step(self.train_loss_MA)
+ else:
+ self.lr_scheduler.step(self.epoch + 1)
+ self.print_to_log_file("lr is now (scheduler) %s" % str(self.optimizer.param_groups[0]['lr']))
+
+ def on_epoch_end(self):
+ return nnUNetTrainer.on_epoch_end(self)
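+
+
+# Editorial sketch (not part of nnU-Net): minimal ReduceLROnPlateau behaviour when the
+# metric it is stepped with stops improving; model, lr and patience are assumptions.
+if __name__ == '__main__':
+ import torch
+ net = torch.nn.Linear(1, 1)
+ opt = torch.optim.Adam(net.parameters(), lr=3e-4)
+ sched = torch.optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.2, patience=3, threshold_mode='abs')
+ for loss in [1.0, 0.9, 0.8] + [0.8] * 10: # plateaus after three epochs
+ sched.step(loss)
+ print(opt.param_groups[0]['lr']) # reduced from 3e-4 by factor 0.2 (possibly more than once)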
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Adam_lr_3en4.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Adam_lr_3en4.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c13a11e6a96f7299e4695416dc52a373c0e5b3d
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Adam_lr_3en4.py
@@ -0,0 +1,24 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNet_variants.optimizer_and_lr.nnUNetTrainerV2_Adam import nnUNetTrainerV2_Adam
+
+
+class nnUNetTrainerV2_Adam_nnUNetTrainerlr(nnUNetTrainerV2_Adam):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.initial_lr = 3e-4
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Ranger_lr1en2.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Ranger_lr1en2.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1f33aa088c2b0fcc8296ed3b7fd97471ab836d4
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Ranger_lr1en2.py
@@ -0,0 +1,31 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.training.optimizer.ranger import Ranger
+
+
+class nnUNetTrainerV2_Ranger_lr1en2(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.initial_lr = 1e-2
+
+ def initialize_optimizer_and_scheduler(self):
+ self.optimizer = Ranger(self.network.parameters(), self.initial_lr, k=6, N_sma_threshhold=5,
+ weight_decay=self.weight_decay)
+ self.lr_scheduler = None
+
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Ranger_lr3en3.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Ranger_lr3en3.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d10da13cc45edcca4e56ebf5ded19d3f7805b23
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Ranger_lr3en3.py
@@ -0,0 +1,31 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.training.optimizer.ranger import Ranger
+
+
+class nnUNetTrainerV2_Ranger_lr3en3(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.initial_lr = 3e-3
+
+ def initialize_optimizer_and_scheduler(self):
+ self.optimizer = Ranger(self.network.parameters(), self.initial_lr, k=6, N_sma_threshhold=5,
+ weight_decay=self.weight_decay)
+ self.lr_scheduler = None
+
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Ranger_lr3en4.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Ranger_lr3en4.py
new file mode 100644
index 0000000000000000000000000000000000000000..11c544add0d559c08c6ac9a9d3d3b5eea4f18376
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_Ranger_lr3en4.py
@@ -0,0 +1,31 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from nnunet.training.optimizer.ranger import Ranger
+
+
+class nnUNetTrainerV2_Ranger_lr3en4(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.initial_lr = 3e-4
+
+ def initialize_optimizer_and_scheduler(self):
+ self.optimizer = Ranger(self.network.parameters(), self.initial_lr, k=6, N_sma_threshhold=5,
+ weight_decay=self.weight_decay)
+ self.lr_scheduler = None
+
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_SGD_ReduceOnPlateau.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_SGD_ReduceOnPlateau.py
new file mode 100644
index 0000000000000000000000000000000000000000..d89a7458776c93889d3ad9bd9f3e1fb5b9804a54
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_SGD_ReduceOnPlateau.py
@@ -0,0 +1,50 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from nnunet.training.network_training.nnUNetTrainer import nnUNetTrainer
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+from torch.optim import lr_scheduler
+
+
+class nnUNetTrainerV2_SGD_ReduceOnPlateau(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+
+ def initialize_optimizer_and_scheduler(self):
+ self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ momentum=0.99, nesterov=True)
+ self.lr_scheduler = lr_scheduler.ReduceLROnPlateau(self.optimizer, mode='min', factor=0.2,
+ patience=self.lr_scheduler_patience,
+ verbose=True, threshold=self.lr_scheduler_eps,
+ threshold_mode="abs")
+
+ def maybe_update_lr(self, epoch=None):
+ # maybe update learning rate
+ if self.lr_scheduler is not None:
+ assert isinstance(self.lr_scheduler, (lr_scheduler.ReduceLROnPlateau, lr_scheduler._LRScheduler))
+
+ if isinstance(self.lr_scheduler, lr_scheduler.ReduceLROnPlateau):
+ # the lr scheduler is stepped with the moving average train loss, which should be more robust
+ if self.epoch > 0: # otherwise self.train_loss_MA is None
+ self.lr_scheduler.step(self.train_loss_MA)
+ else:
+ self.lr_scheduler.step(self.epoch + 1)
+ self.print_to_log_file("lr is now (scheduler) %s" % str(self.optimizer.param_groups[0]['lr']))
+
+ def on_epoch_end(self):
+ return nnUNetTrainer.on_epoch_end(self)
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_SGD_fixedSchedule.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_SGD_fixedSchedule.py
new file mode 100644
index 0000000000000000000000000000000000000000..34a8bea46251bedb9fbfde4bf9f0d86a5216f806
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_SGD_fixedSchedule.py
@@ -0,0 +1,43 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_SGD_fixedSchedule(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+
+ def maybe_update_lr(self, epoch=None):
+ if epoch is None:
+ ep = self.epoch + 1
+ else:
+ ep = epoch
+
+ if 0 <= ep < 500:
+ new_lr = self.initial_lr
+ elif 500 <= ep < 675:
+ new_lr = self.initial_lr * 0.1
+ elif 675 <= ep < 850:
+ new_lr = self.initial_lr * 0.01
+ elif ep >= 850:
+ new_lr = self.initial_lr * 0.001
+ else:
+ raise RuntimeError("Really unexpected things happened, ep=%d" % ep)
+
+ self.optimizer.param_groups[0]['lr'] = new_lr
+ self.print_to_log_file("lr:", self.optimizer.param_groups[0]['lr'])
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_SGD_fixedSchedule2.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_SGD_fixedSchedule2.py
new file mode 100644
index 0000000000000000000000000000000000000000..106de67ecf5d340acf4a9d4864fb9720da3f8d62
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_SGD_fixedSchedule2.py
@@ -0,0 +1,47 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.learning_rate.poly_lr import poly_lr
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_SGD_fixedSchedule2(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+
+ def maybe_update_lr(self, epoch=None):
+ """
+ here we go one step, then use polyLR
+ :param epoch:
+ :return:
+ """
+ if epoch is None:
+ ep = self.epoch + 1
+ else:
+ ep = epoch
+
+ if 0 <= ep < 500:
+ new_lr = self.initial_lr
+ elif 500 <= ep < 675:
+ new_lr = self.initial_lr * 0.1
+ elif ep >= 675:
+ new_lr = poly_lr(ep - 675, self.max_num_epochs - 675, self.initial_lr * 0.1, 0.9)
+ else:
+ raise RuntimeError("Really unexpected things happened, ep=%d" % ep)
+
+ self.optimizer.param_groups[0]['lr'] = new_lr
+ self.print_to_log_file("lr:", self.optimizer.param_groups[0]['lr'])
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_SGD_lrs.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_SGD_lrs.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c1a326441cce20b9d85b2c517a2dd56dd3dbbe0
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_SGD_lrs.py
@@ -0,0 +1,33 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_SGD_lr1en1(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.initial_lr = 1e-1
+
+
+class nnUNetTrainerV2_SGD_lr1en3(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.initial_lr = 1e-3
+
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_cycleAtEnd.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_cycleAtEnd.py
new file mode 100644
index 0000000000000000000000000000000000000000..91d07192513628fefa8a1b33a51037fe4dcb3600
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_cycleAtEnd.py
@@ -0,0 +1,89 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.learning_rate.poly_lr import poly_lr
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+import matplotlib.pyplot as plt
+
+
+def cycle_lr(current_epoch, cycle_length=100, min_lr=1e-6, max_lr=1e-3):
+ num_rising = cycle_length // 2
+ epoch = current_epoch % cycle_length
+ if epoch < num_rising:
+ lr = min_lr + (max_lr - min_lr) / num_rising * epoch
+ else:
+ lr = max_lr - (max_lr - min_lr) / num_rising * (epoch - num_rising)
+ return lr
+
+
+def plot_cycle_lr():
+ xvals = list(range(1000))
+ yvals = [cycle_lr(i, 100, 1e-6, 1e-3) for i in xvals]
+ plt.plot(xvals, yvals)
+ plt.savefig("/home/fabian/temp.png") # save before show(); show() clears the current figure
+ plt.show()
+ plt.close()
+
+
+class nnUNetTrainerV2_cycleAtEnd(nnUNetTrainerV2):
+ """
+ after 1000 epoch, run one iteration through the cycle lr schedule. I want to see if the train loss starts
+ increasing again
+ """
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.max_num_epochs = 1100
+
+ def maybe_update_lr(self, epoch=None):
+ if epoch is None:
+ ep = self.epoch + 1
+ else:
+ ep = epoch
+
+ if ep < 1000:
+ self.optimizer.param_groups[0]['lr'] = poly_lr(ep, 1000, self.initial_lr, 0.9)
+ self.print_to_log_file("lr:", poly_lr(ep, 1000, self.initial_lr, 0.9))
+ else:
+ new_lr = cycle_lr(ep, 100, min_lr=1e-6, max_lr=1e-3) # we don't go all the way back up to initial lr
+ self.optimizer.param_groups[0]['lr'] = new_lr
+ self.print_to_log_file("lr:", new_lr)
+
+
+class nnUNetTrainerV2_cycleAtEnd2(nnUNetTrainerV2):
+ """
+ after 1000 epoch, run one iteration through the cycle lr schedule. I want to see if the train loss starts
+ increasing again
+ """
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.max_num_epochs = 1200
+
+ def maybe_update_lr(self, epoch=None):
+ if epoch is None:
+ ep = self.epoch + 1
+ else:
+ ep = epoch
+
+ if ep < 1000:
+ self.optimizer.param_groups[0]['lr'] = poly_lr(ep, 1000, self.initial_lr, 0.9)
+ self.print_to_log_file("lr:", poly_lr(ep, 1000, self.initial_lr, 0.9))
+ else:
+ new_lr = cycle_lr(ep, 200, min_lr=1e-6, max_lr=1e-2) # we don't go all the way back up to initial lr
+ self.optimizer.param_groups[0]['lr'] = new_lr
+ self.print_to_log_file("lr:", new_lr)
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_fp16.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_fp16.py
new file mode 100644
index 0000000000000000000000000000000000000000..b705bd20e358c70797d7edbd0a8fb22fee5eeca9
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_fp16.py
@@ -0,0 +1,24 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_fp16(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ assert fp16, "This one only accepts fp16=True"
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_momentum09.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_momentum09.py
new file mode 100644
index 0000000000000000000000000000000000000000..0cc9a922d36fd5f84c8a27195d4511ea31f66a24
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_momentum09.py
@@ -0,0 +1,26 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_momentum09(nnUNetTrainerV2):
+ def initialize_optimizer_and_scheduler(self):
+ assert self.network is not None, "self.initialize_network must be called first"
+ self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ momentum=0.9, nesterov=True)
+ self.lr_scheduler = None
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_momentum095.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_momentum095.py
new file mode 100644
index 0000000000000000000000000000000000000000..e046239297d9ce530fed13f48b31d711f179e7c8
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_momentum095.py
@@ -0,0 +1,26 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_momentum095(nnUNetTrainerV2):
+ def initialize_optimizer_and_scheduler(self):
+ assert self.network is not None, "self.initialize_network must be called first"
+ self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ momentum=0.95, nesterov=True)
+ self.lr_scheduler = None
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_momentum098.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_momentum098.py
new file mode 100644
index 0000000000000000000000000000000000000000..58d9072c61cb089f4765e9f65d064f8f68d6d524
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_momentum098.py
@@ -0,0 +1,26 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_momentum098(nnUNetTrainerV2):
+ def initialize_optimizer_and_scheduler(self):
+ assert self.network is not None, "self.initialize_network must be called first"
+ self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ momentum=0.98, nesterov=True)
+ self.lr_scheduler = None
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_momentum09in2D.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_momentum09in2D.py
new file mode 100644
index 0000000000000000000000000000000000000000..83ffbec64ec86260209ae7673cf2b5f09256218e
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_momentum09in2D.py
@@ -0,0 +1,29 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_momentum09in2D(nnUNetTrainerV2):
+ def initialize_optimizer_and_scheduler(self):
+ if self.threeD:
+ momentum = 0.99
+ else:
+ momentum = 0.9
+ assert self.network is not None, "self.initialize_network must be called first"
+ self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ momentum=momentum, nesterov=True)
+ self.lr_scheduler = None
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_reduceMomentumDuringTraining.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_reduceMomentumDuringTraining.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd8860fb33b52fd2d734fa58878458d6da803db0
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_reduceMomentumDuringTraining.py
@@ -0,0 +1,46 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_reduceMomentumDuringTraining(nnUNetTrainerV2):
+ """
+ This implementation will not work with LR scheduler!!!!!!!!!!
+
+ After epoch 800, linearly decrease momentum from 0.99 to 0.9
+ """
+ def initialize_optimizer_and_scheduler(self):
+ current_momentum = 0.99
+ min_momentum = 0.9
+
+ if self.epoch > 800:
+ current_momentum = current_momentum - (current_momentum - min_momentum) / 200 * (self.epoch - 800)
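+ # e.g. with 1000 epochs total: epoch 900 -> 0.99 - 0.09 * 100 / 200 = 0.945,
+ # epoch 1000 -> 0.9 (linear interpolation over the last 200 epochs)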
+
+ self.print_to_log_file("current momentum", current_momentum)
+ assert self.network is not None, "self.initialize_network must be called first"
+ if self.optimizer is None:
+ self.optimizer = torch.optim.SGD(self.network.parameters(), self.initial_lr, weight_decay=self.weight_decay,
+ momentum=0.99, nesterov=True)
+ else:
+ # can't reinstantiate because that would break NVIDIA AMP
+ self.optimizer.param_groups[0]["momentum"] = current_momentum
+ self.lr_scheduler = None
+
+ def on_epoch_end(self):
+ self.initialize_optimizer_and_scheduler()
+ return super().on_epoch_end()
diff --git a/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_warmup.py b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_warmup.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bac0760899cdf74351210128ef3505e9a52cf74
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/optimizer_and_lr/nnUNetTrainerV2_warmup.py
@@ -0,0 +1,39 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_warmup(nnUNetTrainerV2):
+ def __init__(self, plans_file, fold, output_folder=None, dataset_directory=None, batch_dice=True, stage=None,
+ unpack_data=True, deterministic=True, fp16=False):
+ super().__init__(plans_file, fold, output_folder, dataset_directory, batch_dice, stage, unpack_data,
+ deterministic, fp16)
+ self.max_num_epochs = 1050
+
+ def maybe_update_lr(self, epoch=None):
+ if self.epoch < 50:
+ # epoch 49 is the last warmup epoch
+ # we increase lr linearly from initial_lr / 50 up to initial_lr
+ lr = (self.epoch + 1) / 50 * self.initial_lr
+ self.optimizer.param_groups[0]['lr'] = lr
+ self.print_to_log_file("epoch:", self.epoch, "lr:", lr)
+ else:
+ if epoch is not None:
+ ep = epoch - 49
+ else:
+ ep = self.epoch - 49
+ assert ep > 0, "epoch must be >0"
+ return super().maybe_update_lr(ep)
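+
+ # Warmup values for reference (nnUNetTrainerV2 uses initial_lr = 1e-2):
+ # epoch 0 -> initial_lr / 50, epoch 24 -> initial_lr / 2, epoch 49 -> initial_lr;
+ # from epoch 50 on, the parent poly schedule continues with ep = self.epoch - 49.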
diff --git a/nnunet/training/network_training/nnUNet_variants/resampling/__init__.py b/nnunet/training/network_training/nnUNet_variants/resampling/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/training/network_training/nnUNet_variants/resampling/nnUNetTrainerV2_resample33.py b/nnunet/training/network_training/nnUNet_variants/resampling/nnUNetTrainerV2_resample33.py
new file mode 100644
index 0000000000000000000000000000000000000000..7111e82f0428e7d913e717471189a10ce3bdb34b
--- /dev/null
+++ b/nnunet/training/network_training/nnUNet_variants/resampling/nnUNetTrainerV2_resample33.py
@@ -0,0 +1,56 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.inference.segmentation_export import save_segmentation_nifti_from_softmax
+from nnunet.training.network_training.nnUNetTrainerV2 import nnUNetTrainerV2
+
+
+class nnUNetTrainerV2_resample33(nnUNetTrainerV2):
+ def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
+ step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True, overwrite: bool = True,
+ validation_folder_name: str = 'validation_raw', debug: bool = False, all_in_gpu: bool = False,
+ segmentation_export_kwargs: dict = None, run_postprocessing_on_folds: bool = True):
+ return super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window, step_size=step_size,
+ save_softmax=save_softmax, use_gaussian=use_gaussian, overwrite=overwrite,
+ validation_folder_name=validation_folder_name, debug=debug, all_in_gpu=all_in_gpu,
+ segmentation_export_kwargs=segmentation_export_kwargs,
+ run_postprocessing_on_folds=run_postprocessing_on_folds)
+
+ def preprocess_predict_nifti(self, input_files, output_file=None, softmax_ouput_file=None,
+ mixed_precision: bool = True):
+ """
+ Use this to predict new data
+ :param input_files:
+ :param output_file:
+ :param softmax_ouput_file:
+ :param mixed_precision:
+ :return:
+ """
+ print("preprocessing...")
+ d, s, properties = self.preprocess_patient(input_files)
+ print("predicting...")
+ pred = self.predict_preprocessed_data_return_seg_and_softmax(d, do_mirroring=self.data_aug_params["do_mirror"],
+ mirror_axes=self.data_aug_params['mirror_axes'],
+ use_sliding_window=True, step_size=0.5,
+ use_gaussian=True, pad_border_mode='constant',
+ pad_kwargs={'constant_values': 0},
+ all_in_gpu=True,
+ mixed_precision=mixed_precision)[1]
+ pred = pred.transpose([0] + [i + 1 for i in self.transpose_backward])
+
+ print("resampling to original spacing and nifti export...")
+ save_segmentation_nifti_from_softmax(pred, output_file, properties, 3, None, None, None, softmax_ouput_file,
+ None, force_separate_z=False, interpolation_order_z=3)
+ print("done")
diff --git a/nnunet/training/optimizer/__init__.py b/nnunet/training/optimizer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/nnunet/training/optimizer/ranger.py b/nnunet/training/optimizer/ranger.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e8acbdf8616eae1aeca0b953879384ebe85de1f
--- /dev/null
+++ b/nnunet/training/optimizer/ranger.py
@@ -0,0 +1,152 @@
+############
+# https://github.com/lessw2020/Ranger-Deep-Learning-Optimizer
+# This code was taken from the repo above and was not created by me (Fabian)! Full credit goes to the original authors
+############
+
+import math
+import torch
+from torch.optim.optimizer import Optimizer
+
+
+class Ranger(Optimizer):
+
+ def __init__(self, params, lr=1e-3, alpha=0.5, k=6, N_sma_threshhold=5, betas=(.95, 0.999), eps=1e-5,
+ weight_decay=0):
+ # parameter checks
+ if not 0.0 <= alpha <= 1.0:
+ raise ValueError(f'Invalid slow update rate: {alpha}')
+ if not 1 <= k:
+ raise ValueError(f'Invalid lookahead steps: {k}')
+ if not lr > 0:
+ raise ValueError(f'Invalid Learning Rate: {lr}')
+ if not eps > 0:
+ raise ValueError(f'Invalid eps: {eps}')
+
+ # parameter comments:
+ # beta1 (momentum) of .95 seems to work better than .90...
+ # N_sma_threshold of 5 seems better in testing than 4.
+ # In both cases, worth testing on your dataset (.90 vs .95, 4 vs 5) to make sure which works best for you.
+
+ # prep defaults and init torch.optim base
+ defaults = dict(lr=lr, alpha=alpha, k=k, step_counter=0, betas=betas, N_sma_threshhold=N_sma_threshhold,
+ eps=eps, weight_decay=weight_decay)
+ super().__init__(params, defaults)
+
+ # adjustable threshold
+ self.N_sma_threshhold = N_sma_threshhold
+
+ # now we can get to work...
+ # removed as we now use step from RAdam...no need for duplicate step counting
+ # for group in self.param_groups:
+ # group["step_counter"] = 0
+ # print("group step counter init")
+
+ # look ahead params
+ self.alpha = alpha
+ self.k = k
+
+ # radam buffer for state
+ self.radam_buffer = [[None, None, None] for ind in range(10)]
+
+ # self.first_run_check=0
+
+ # lookahead weights
+ # 9/2/19 - lookahead param tensors have been moved to state storage.
+ # This should resolve issues with load/save where weights were left in GPU memory from first load, slowing down future runs.
+
+ # self.slow_weights = [[p.clone().detach() for p in group['params']]
+ # for group in self.param_groups]
+
+ # don't use grad for lookahead weights
+ # for w in it.chain(*self.slow_weights):
+ # w.requires_grad = False
+
+ def __setstate__(self, state):
+ print("set state called")
+ super(Ranger, self).__setstate__(state)
+
+ def step(self, closure=None):
+ loss = None
+ # note - below is commented out b/c I have other work that passes back the loss as a float, and thus not a callable closure.
+ # Uncomment if you need to use the actual closure...
+
+ # if closure is not None:
+ # loss = closure()
+
+ # Evaluate averages and grad, update param tensors
+ for group in self.param_groups:
+
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ grad = p.grad.data.float()
+ if grad.is_sparse:
+ raise RuntimeError('Ranger optimizer does not support sparse gradients')
+
+ p_data_fp32 = p.data.float()
+
+ state = self.state[p] # get state dict for this param
+
+ if len(state) == 0: # if first time to run...init dictionary with our desired entries
+ # if self.first_run_check==0:
+ # self.first_run_check=1
+ # print("Initializing slow buffer...should not see this at load from saved model!")
+ state['step'] = 0
+ state['exp_avg'] = torch.zeros_like(p_data_fp32)
+ state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
+
+ # look ahead weight storage now in state dict
+ state['slow_buffer'] = torch.empty_like(p.data)
+ state['slow_buffer'].copy_(p.data)
+
+ else:
+ state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
+ state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
+
+ # begin computations
+ exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
+ beta1, beta2 = group['betas']
+
+ # compute variance mov avg
+ exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
+ # compute mean moving avg
+ exp_avg.mul_(beta1).add_(1 - beta1, grad)
+
+ state['step'] += 1
+
+ buffered = self.radam_buffer[int(state['step'] % 10)]
+ if state['step'] == buffered[0]:
+ N_sma, step_size = buffered[1], buffered[2]
+ else:
+ buffered[0] = state['step']
+ beta2_t = beta2 ** state['step']
+ N_sma_max = 2 / (1 - beta2) - 1
+ N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
+ buffered[1] = N_sma
+ if N_sma > self.N_sma_threshhold:
+ step_size = math.sqrt(
+ (1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
+ N_sma_max - 2)) / (1 - beta1 ** state['step'])
+ else:
+ step_size = 1.0 / (1 - beta1 ** state['step'])
+ buffered[2] = step_size
+
+ if group['weight_decay'] != 0:
+ p_data_fp32.add_(-group['weight_decay'] * group['lr'], p_data_fp32)
+
+ if N_sma > self.N_sma_threshhold:
+ denom = exp_avg_sq.sqrt().add_(group['eps'])
+ p_data_fp32.addcdiv_(-step_size * group['lr'], exp_avg, denom)
+ else:
+ p_data_fp32.add_(-step_size * group['lr'], exp_avg)
+
+ p.data.copy_(p_data_fp32)
+
+ # integrated look ahead...
+ # we do it at the param level instead of group level
+ if state['step'] % group['k'] == 0:
+ slow_p = state['slow_buffer'] # get access to slow param tensor
+ slow_p.add_(self.alpha, p.data - slow_p) # (fast weights - slow weights) * alpha
+ p.data.copy_(slow_p) # copy interpolated weights to RAdam param tensor
+
+ return loss
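+
+
+if __name__ == "__main__":
+ # Minimal usage sketch (an illustration, not part of the original Ranger repo):
+ # Ranger is constructed like any torch.optim optimizer. Note that step() above
+ # uses the older add_/addcmul_ call signatures of the torch versions of its time.
+ net = torch.nn.Linear(10, 2)
+ opt = Ranger(net.parameters(), lr=1e-3, alpha=0.5, k=6)
+ for _ in range(3):
+ opt.zero_grad()
+ loss = net(torch.rand(4, 10)).sum()
+ loss.backward()
+ opt.step()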
diff --git a/nnunet/utilities/__init__.py b/nnunet/utilities/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72b8078b9dddddf22182fec2555d8d118ea72622
--- /dev/null
+++ b/nnunet/utilities/__init__.py
@@ -0,0 +1,2 @@
+from __future__ import absolute_import
+from . import *
\ No newline at end of file
diff --git a/nnunet/utilities/distributed.py b/nnunet/utilities/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2dcab5680f4d472774e4715fba896e0ff05e155
--- /dev/null
+++ b/nnunet/utilities/distributed.py
@@ -0,0 +1,89 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import torch
+from torch import distributed
+from torch import autograd
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+
+def print_if_rank0(*args):
+ if distributed.get_rank() == 0:
+ print(*args)
+
+
+class awesome_allgather_function(autograd.Function):
+ @staticmethod
+ def forward(ctx, input):
+ world_size = distributed.get_world_size()
+ # create a destination list for the allgather: one tensor per rank in the world
+ allgather_list = [torch.empty_like(input) for _ in range(world_size)]
+ #if distributed.get_rank() == 0:
+ # import IPython;IPython.embed()
+ distributed.all_gather(allgather_list, input)
+ return torch.cat(allgather_list, dim=0)
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ #print_if_rank0("backward grad_output len", len(grad_output))
+ #print_if_rank0("backward grad_output shape", grad_output.shape)
+ grads_per_rank = grad_output.shape[0] // distributed.get_world_size()
+ rank = distributed.get_rank()
+ # We'll receive gradients for the entire catted forward output, so to mimic DataParallel,
+ # return only the slice that corresponds to this process's input:
+ sl = slice(rank * grads_per_rank, (rank + 1) * grads_per_rank)
+ #print("worker", rank, "backward slice", sl)
+ return grad_output[sl]
+
+
+if __name__ == "__main__":
+ import torch.distributed as dist
+ import argparse
+ from torch import nn
+ from torch.optim import Adam
+
+ argumentparser = argparse.ArgumentParser()
+ argumentparser.add_argument("--local_rank", type=int)
+ args = argumentparser.parse_args()
+
+ torch.cuda.set_device(args.local_rank)
+ dist.init_process_group(backend='nccl', init_method='env://')
+
+ rnd = torch.rand((5, 2)).cuda()
+
+ rnd_gathered = awesome_allgather_function.apply(rnd)
+ print("gathering random tensors\nbefore\b", rnd, "\nafter\n", rnd_gathered)
+
+ # so far this works as expected
+ print("now running a DDP model")
+ c = nn.Conv2d(2, 3, 3, 1, 1, 1, 1, True).cuda()
+ c = DDP(c)
+ opt = Adam(c.parameters())
+
+ bs = 5
+ if dist.get_rank() == 0:
+ bs = 4
+ inp = torch.rand((bs, 2, 5, 5)).cuda()
+
+ out = c(inp)
+ print("output_shape", out.shape)
+
+ out_gathered = awesome_allgather_function.apply(out)
+ print("output_shape_after_gather", out_gathered.shape)
+ # this also works
+
+ loss = out_gathered.sum()
+ loss.backward()
+ opt.step()
diff --git a/nnunet/utilities/file_conversions.py b/nnunet/utilities/file_conversions.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f3b4a843780c7708a31b84924c60a3cfcbd1b54
--- /dev/null
+++ b/nnunet/utilities/file_conversions.py
@@ -0,0 +1,115 @@
+from typing import Tuple, List, Union
+from skimage import io
+import SimpleITK as sitk
+import numpy as np
+import tifffile
+
+
+def convert_2d_image_to_nifti(input_filename: str, output_filename_truncated: str, spacing=(999, 1, 1),
+ transform=None, is_seg: bool = False) -> None:
+ """
+ Reads an image (must be a format recognized by skimage.io.imread) and converts it into a series of niftis.
+ The image can have an arbitrary number of input channels which will be exported separately (_0000.nii.gz,
+ _0001.nii.gz, etc for images and only .nii.gz for seg).
+ Spacing can be ignored most of the time.
+ Note: 2D images are often natural images that do not have a voxel spacing usable for resampling. Such images
+ must be resampled by you prior to converting them to nifti!
+
+ Datasets converted with this utility can only be used with the 2d U-Net configuration of nnU-Net
+
+ If transform is not None, it will be applied to the image after loading.
+
+ Segmentations will be converted to np.uint32!
+
+ :param is_seg:
+ :param transform:
+ :param input_filename:
+ :param output_filename_truncated: do not use a file ending for this one! Example: output_name='./converted/image1'. This
+ function will add the suffix (_0000) and file ending (.nii.gz) for you.
+ :param spacing:
+ :return:
+ """
+ img = io.imread(input_filename)
+
+ if transform is not None:
+ img = transform(img)
+
+ if len(img.shape) == 2: # 2d image with no color channels
+ img = img[None, None] # add dimensions
+ else:
+ assert len(img.shape) == 3, "image should be 3d with color channel last but has shape %s" % str(img.shape)
+ # we assume that the color channel is the last dimension. Transpose it to be in first
+ img = img.transpose((2, 0, 1))
+ # add third dimension
+ img = img[:, None]
+
+ # image is now (c, 1, x, y): channels first, each channel a single-slice 3d volume since the input is 2d
+ if is_seg:
+ assert img.shape[0] == 1, 'segmentations can only have one color channel, not sure what happened here'
+
+ for j, i in enumerate(img):
+
+ if is_seg:
+ i = i.astype(np.uint32)
+
+ itk_img = sitk.GetImageFromArray(i)
+ itk_img.SetSpacing(list(spacing)[::-1])
+ if not is_seg:
+ sitk.WriteImage(itk_img, output_filename_truncated + "_%04.0d.nii.gz" % j)
+ else:
+ sitk.WriteImage(itk_img, output_filename_truncated + ".nii.gz")
+
+
+def convert_3d_tiff_to_nifti(filenames: List[str], output_name: str, spacing: Union[tuple, list], transform=None, is_seg=False) -> None:
+ """
+ filenames must be a list of strings, each pointing to a separate 3d tiff file. One file per modality. If your data
+ only has one imaging modality, simply pass a list with only a single entry
+
+ Files in filenames must be readable with tifffile.imread.
+
+ Note: we always only pass one file into tifffile.imread, not multiple (even though it supports it). This is because
+ I am not familiar enough with this functionality and would like to have control over what happens.
+
+ If transform is not None, it will be applied to the image after loading.
+
+ :param transform:
+ :param filenames:
+ :param output_name:
+ :param spacing:
+ :return:
+ """
+ if is_seg:
+ assert len(filenames) == 1
+
+ for j, i in enumerate(filenames):
+ img = tifffile.imread(i)
+
+ if transform is not None:
+ img = transform(img)
+
+ itk_img = sitk.GetImageFromArray(img)
+ itk_img.SetSpacing(list(spacing)[::-1])
+
+ if not is_seg:
+ sitk.WriteImage(itk_img, output_name + "_%04.0d.nii.gz" % j)
+ else:
+ sitk.WriteImage(itk_img, output_name + ".nii.gz")
+
+
+def convert_2d_segmentation_nifti_to_img(nifti_file: str, output_filename: str, transform=None, export_dtype=np.uint8):
+ img = sitk.GetArrayFromImage(sitk.ReadImage(nifti_file))
+ assert img.shape[0] == 1, "This function can only export 2D segmentations!"
+ img = img[0]
+ if transform is not None:
+ img = transform(img)
+
+ io.imsave(output_filename, img.astype(export_dtype), check_contrast=False)
+
+
+def convert_3d_segmentation_nifti_to_tiff(nifti_file: str, output_filename: str, transform=None, export_dtype=np.uint8):
+ img = sitk.GetArrayFromImage(sitk.ReadImage(nifti_file))
+ assert len(img.shape) == 3, "This function can only export 3D segmentations!"
+ if transform is not None:
+ img = transform(img)
+
+ tifffile.imsave(output_filename, img.astype(export_dtype))
diff --git a/nnunet/utilities/file_endings.py b/nnunet/utilities/file_endings.py
new file mode 100644
index 0000000000000000000000000000000000000000..a37c19758ecdea707449dc3d14ca56d8155c93bf
--- /dev/null
+++ b/nnunet/utilities/file_endings.py
@@ -0,0 +1,30 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from batchgenerators.utilities.file_and_folder_operations import *
+
+
+def remove_trailing_slash(filename: str):
+ while filename.endswith('/'):
+ filename = filename[:-1]
+ return filename
+
+
+def maybe_add_0000_to_all_niigz(folder):
+ nii_gz = subfiles(folder, suffix='.nii.gz')
+ for n in nii_gz:
+ n = remove_trailing_slash(n)
+ if not n.endswith('_0000.nii.gz'):
+ os.rename(n, n[:-7] + '_0000.nii.gz')
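+
+# Example: 'case1.nii.gz' is renamed to 'case1_0000.nii.gz'; files that already
+# end in '_0000.nii.gz' are left untouched.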
diff --git a/nnunet/utilities/folder_names.py b/nnunet/utilities/folder_names.py
new file mode 100644
index 0000000000000000000000000000000000000000..708b132ebc62e7fdf1ba5f8e3a2c8baec44d33fb
--- /dev/null
+++ b/nnunet/utilities/folder_names.py
@@ -0,0 +1,47 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.paths import network_training_output_dir
+
+
+def get_output_folder_name(model: str, task: str = None, trainer: str = None, plans: str = None, fold: int = None,
+ overwrite_training_output_dir: str = None):
+ """
+ Retrieves the correct output directory for the nnU-Net model described by the input parameters
+
+ :param model:
+ :param task:
+ :param trainer:
+ :param plans:
+ :param fold:
+ :param overwrite_training_output_dir:
+ :return:
+ """
+ assert model in ["2d", "3d_cascade_fullres", '3d_fullres', '3d_lowres']
+
+ if overwrite_training_output_dir is not None:
+ tr_dir = overwrite_training_output_dir
+ else:
+ tr_dir = network_training_output_dir
+
+ current = join(tr_dir, model)
+ if task is not None:
+ current = join(current, task)
+ if trainer is not None and plans is not None:
+ current = join(current, trainer + "__" + plans)
+ if fold is not None:
+ current = join(current, "fold_%d" % fold)
+ return current
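+
+
+# Example (hypothetical names):
+# get_output_folder_name('3d_fullres', 'Task003_Liver', 'nnUNetTrainerV2', 'nnUNetPlansv2.1', 0)
+# -> <network_training_output_dir>/3d_fullres/Task003_Liver/nnUNetTrainerV2__nnUNetPlansv2.1/fold_0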
diff --git a/nnunet/utilities/image_reorientation.py b/nnunet/utilities/image_reorientation.py
new file mode 100644
index 0000000000000000000000000000000000000000..068f0caa4e611eca5769121892501206ee009cb0
--- /dev/null
+++ b/nnunet/utilities/image_reorientation.py
@@ -0,0 +1,121 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import nibabel as nib
+from nibabel import io_orientation
+from batchgenerators.utilities.file_and_folder_operations import *
+import numpy as np
+import os
+from multiprocessing import Pool
+import SimpleITK as sitk
+
+
+def print_shapes(folder: str) -> None:
+ for i in subfiles(folder, suffix='.nii.gz'):
+ tmp = sitk.ReadImage(i)
+ print(sitk.GetArrayFromImage(tmp).shape, tmp.GetSpacing())
+
+
+def reorient_to_ras(image: str) -> None:
+ """
+ Will overwrite image!!!
+ :param image:
+ :return:
+ """
+ assert image.endswith('.nii.gz')
+ origaffine_pkl = image[:-7] + '_originalAffine.pkl'
+ if not isfile(origaffine_pkl):
+ img = nib.load(image)
+ original_affine = img.affine
+ original_axcode = nib.aff2axcodes(img.affine)
+ img = img.as_reoriented(io_orientation(img.affine))
+ new_axcode = nib.aff2axcodes(img.affine)
+ print(image.split('/')[-1], 'original axcode', original_axcode, 'now (should be ras)', new_axcode)
+ nib.save(img, image)
+ save_pickle((original_affine, original_axcode), origaffine_pkl)
+
+
+def revert_reorientation(image: str) -> None:
+ assert image.endswith('.nii.gz')
+ expected_pkl = image[:-7] + '_originalAffine.pkl'
+ assert isfile(expected_pkl), 'Must have a file with the original affine, as created by ' \
+ 'reorient_to_ras. Expected filename: %s' % \
+ expected_pkl
+ original_affine, original_axcode = load_pickle(image[:-7] + '_originalAffine.pkl')
+ img = nib.load(image)
+ before_revert = nib.aff2axcodes(img.affine)
+ img = img.as_reoriented(io_orientation(original_affine))
+ after_revert = nib.aff2axcodes(img.affine)
+ print('before revert', before_revert, 'after revert', after_revert)
+ restored_affine = img.affine
+ assert np.all(np.isclose(original_affine, restored_affine)), 'restored affine does not match original affine, ' \
+ 'aborting!'
+ nib.save(img, image)
+ os.remove(expected_pkl)
+
+
+def reorient_all_images_in_folder_to_ras(folder: str, num_processes: int = 8):
+ p = Pool(num_processes)
+ nii_files = subfiles(folder, suffix='.nii.gz', join=True)
+ p.map(reorient_to_ras, nii_files)
+ p.close()
+ p.join()
+
+
+def revert_orientation_on_all_images_in_folder(folder: str, num_processes: int = 8):
+ p = Pool(num_processes)
+ nii_files = subfiles(folder, suffix='.nii.gz', join=True)
+ p.map(revert_reorientation, nii_files)
+ p.close()
+ p.join()
+
+
+if __name__ == '__main__':
+ """nib.as_closest_canonical()
+ test_img = '/home/fabian/data/la_005_0000.nii.gz'
+ test_img_reorient = test_img[:-7] + '_reorient.nii.gz'
+ test_img_restored = test_img[:-7] + '_restored.nii.gz'
+
+ img = nib.load(test_img)
+ print('loaded original')
+ print('shape', img.shape)
+ print('affine', img.affine)
+ original_affine = img.affine
+ original_axcode = nib.aff2axcodes(img.affine)
+ print('orientation', nib.aff2axcodes(img.affine))
+
+ print('reorienting...')
+ img = img.as_reoriented(io_orientation(img.affine))
+ nib.save(img, test_img_reorient)
+
+ print('now loading the reoriented img')
+ img = nib.load(test_img_reorient)
+ print('loaded original')
+ print('shape', img.shape)
+ print('affine', img.affine)
+ reorient_affine = img.affine
+ reorient_axcode = nib.aff2axcodes(img.affine)
+ print('orientation', nib.aff2axcodes(img.affine))
+
+ print('restoring original geometry')
+ img = img.as_reoriented(io_orientation(original_affine))
+ restored_affine = img.affine
+ nib.save(img, test_img_restored)
+
+ print('now loading the restored img')
+ img = nib.load(test_img_restored)
+ print('loaded original')
+ print('shape', img.shape)
+ print('affine', img.affine)
+ print('orientation', nib.aff2axcodes(img.affine))"""
diff --git a/nnunet/utilities/nd_softmax.py b/nnunet/utilities/nd_softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..98f3161a1af71dc364b74a56db0d930371206ccb
--- /dev/null
+++ b/nnunet/utilities/nd_softmax.py
@@ -0,0 +1,21 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+
+
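+# expects network logits of shape (b, c, x, y[, z]); the softmax is taken over
+# the channel dimension (dim 1)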
+softmax_helper = lambda x: F.softmax(x, 1)
+
diff --git a/nnunet/utilities/one_hot_encoding.py b/nnunet/utilities/one_hot_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c5e95b00cfe5e5d3b37934895b833a40f3514fc
--- /dev/null
+++ b/nnunet/utilities/one_hot_encoding.py
@@ -0,0 +1,24 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+
+def to_one_hot(seg, all_seg_labels=None):
+ if all_seg_labels is None:
+ all_seg_labels = np.unique(seg)
+ result = np.zeros((len(all_seg_labels), *seg.shape), dtype=seg.dtype)
+ for i, l in enumerate(all_seg_labels):
+ result[i][seg == l] = 1
+ return result
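+
+
+# Example: to_one_hot(np.array([[0, 1], [2, 0]]), (0, 1, 2)) has shape (3, 2, 2);
+# result[1] is the binary mask of label 1, i.e. [[0, 1], [0, 0]].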
diff --git a/nnunet/utilities/overlay_plots.py b/nnunet/utilities/overlay_plots.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9b399173e7d31fddba98175107250bcad0c76df
--- /dev/null
+++ b/nnunet/utilities/overlay_plots.py
@@ -0,0 +1,204 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from multiprocessing.pool import Pool
+
+import numpy as np
+import SimpleITK as sitk
+from nnunet.utilities.task_name_id_conversion import convert_task_name_to_id, convert_id_to_task_name
+from batchgenerators.utilities.file_and_folder_operations import *
+from nnunet.paths import *
+
+color_cycle = (
+ "000000",
+ "4363d8",
+ "f58231",
+ "3cb44b",
+ "e6194B",
+ "911eb4",
+ "ffe119",
+ "bfef45",
+ "42d4f4",
+ "f032e6",
+ "000075",
+ "9A6324",
+ "808000",
+ "800000",
+ "469990",
+)
+
+
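+# e.g. hex_to_rgb("4363d8") -> (67, 99, 216)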
+def hex_to_rgb(hex: str):
+ assert len(hex) == 6
+ return tuple(int(hex[i:i + 2], 16) for i in (0, 2, 4))
+
+
+def generate_overlay(input_image: np.ndarray, segmentation: np.ndarray, mapping: dict = None, color_cycle=color_cycle,
+ overlay_intensity=0.6):
+ """
+ image must be a color image, so last dimension must be 3. if image is grayscale, tile it first!
+ Segmentation must be label map of same shape as image (w/o color channels)
+ mapping can be label_id -> idx_in_cycle or None
+
+ returned image is scaled to [0, 255]!!!
+ """
+ # assert len(image.shape) == len(segmentation.shape)
+ # assert all([i == j for i, j in zip(image.shape, segmentation.shape)])
+
+ # create a copy of image
+ image = np.copy(input_image)
+
+ if len(image.shape) == 2:
+ image = np.tile(image[:, :, None], (1, 1, 3))
+ elif len(image.shape) == 3:
+ assert image.shape[2] == 3, 'if 3d image is given the last dimension must be the color channels ' \
+ '(3 channels). Only 2D images are supported'
+
+ else:
+ raise RuntimeError("unexpected image shape. only 2D images and 2D images with color channels (color in "
+ "last dimension) are supported")
+
+ # rescale image to [0, 255]
+ image = image - image.min()
+ image = image / image.max() * 255
+
+ # create output
+
+ if mapping is None:
+ uniques = np.unique(segmentation)
+ mapping = {i: c for c, i in enumerate(uniques)}
+
+ for l in mapping.keys():
+ image[segmentation == l] += overlay_intensity * np.array(hex_to_rgb(color_cycle[mapping[l]]))
+
+ # rescale result to [0, 255]
+ image = image / image.max() * 255
+ return image.astype(np.uint8)
+
+
+def plot_overlay(image_file: str, segmentation_file: str, output_file: str, overlay_intensity: float = 0.6):
+ import matplotlib.pyplot as plt
+
+ image = sitk.GetArrayFromImage(sitk.ReadImage(image_file))
+ seg = sitk.GetArrayFromImage(sitk.ReadImage(segmentation_file))
+ assert all([i == j for i, j in zip(image.shape, seg.shape)]), "image and seg do not have the same shape: %s, %s" % (
+ image_file, segmentation_file)
+
+ assert len(image.shape) == 3, 'only 3D images/segs are supported'
+
+ fg_mask = seg != 0
+ fg_per_slice = fg_mask.sum((1, 2))
+ selected_slice = np.argmax(fg_per_slice)
+
+ overlay = generate_overlay(image[selected_slice], seg[selected_slice], overlay_intensity=overlay_intensity)
+
+ plt.imsave(output_file, overlay)
+
+
+def plot_overlay_preprocessed(case_file: str, output_file: str, overlay_intensity: float = 0.6, modality_index=0):
+ import matplotlib.pyplot as plt
+ data = np.load(case_file)['data']
+
+ assert modality_index < (data.shape[0] - 1), 'This dataset only supports modality index up to %d' % (data.shape[0] - 2)
+
+ image = data[modality_index]
+ seg = data[-1]
+ seg[seg < 0] = 0
+
+ fg_mask = seg > 0
+ fg_per_slice = fg_mask.sum((1, 2))
+ selected_slice = np.argmax(fg_per_slice)
+
+ overlay = generate_overlay(image[selected_slice], seg[selected_slice], overlay_intensity=overlay_intensity)
+
+ plt.imsave(output_file, overlay)
+
+
+def multiprocessing_plot_overlay(list_of_image_files, list_of_seg_files, list_of_output_files, overlay_intensity,
+ num_processes=8):
+ p = Pool(num_processes)
+ r = p.starmap_async(plot_overlay, zip(
+ list_of_image_files, list_of_seg_files, list_of_output_files, [overlay_intensity] * len(list_of_output_files)
+ ))
+ r.get()
+ p.close()
+ p.join()
+
+
+def multiprocessing_plot_overlay_preprocessed(list_of_case_files, list_of_output_files, overlay_intensity,
+ num_processes=8, modality_index=0):
+ p = Pool(num_processes)
+ r = p.starmap_async(plot_overlay_preprocessed, zip(
+ list_of_case_files, list_of_output_files, [overlay_intensity] * len(list_of_output_files),
+ [modality_index] * len(list_of_output_files)
+ ))
+ r.get()
+ p.close()
+ p.join()
+
+
+def generate_overlays_for_task(task_name_or_id, output_folder, num_processes=8, modality_idx=0, use_preprocessed=True,
+ data_identifier=default_data_identifier):
+ if isinstance(task_name_or_id, str):
+ if not task_name_or_id.startswith("Task"):
+ task_name_or_id = int(task_name_or_id)
+ task_name = convert_id_to_task_name(task_name_or_id)
+ else:
+ task_name = task_name_or_id
+ else:
+ task_name = convert_id_to_task_name(int(task_name_or_id))
+
+ if not use_preprocessed:
+ folder = join(nnUNet_raw_data, task_name)
+
+ identifiers = [i[:-7] for i in subfiles(join(folder, 'labelsTr'), suffix='.nii.gz', join=False)]
+
+ image_files = [join(folder, 'imagesTr', i + "_%04.0d.nii.gz" % modality_idx) for i in identifiers]
+ seg_files = [join(folder, 'labelsTr', i + ".nii.gz") for i in identifiers]
+
+ assert all([isfile(i) for i in image_files])
+ assert all([isfile(i) for i in seg_files])
+
+ maybe_mkdir_p(output_folder)
+ output_files = [join(output_folder, i + '.png') for i in identifiers]
+ multiprocessing_plot_overlay(image_files, seg_files, output_files, 0.6, num_processes)
+ else:
+ folder = join(preprocessing_output_dir, task_name)
+ if not isdir(folder): raise RuntimeError("run preprocessing for that task first")
+ matching_folders = subdirs(folder, prefix=data_identifier + "_stage")
+ if len(matching_folders) == 0:
+ raise RuntimeError("run preprocessing for that task first (use the default experiment planner!)")
+ matching_folders.sort()
+ folder = matching_folders[-1]
+ identifiers = [i[:-4] for i in subfiles(folder, suffix='.npz', join=False)]
+ maybe_mkdir_p(output_folder)
+ output_files = [join(output_folder, i + '.png') for i in identifiers]
+ image_files = [join(folder, i + ".npz") for i in identifiers]
+ multiprocessing_plot_overlay_preprocessed(image_files, output_files, overlay_intensity=0.6,
+ num_processes=num_processes, modality_index=modality_idx)
+
+
+def entry_point_generate_overlay():
+ import argparse
+ parser = argparse.ArgumentParser("Plots png overlays of the slice with the most foreground. Note that this "
+ "disregards spacing information!")
+ parser.add_argument('-t', type=str, help="task name or task ID", required=True)
+ parser.add_argument('-o', type=str, help="output folder", required=True)
+ parser.add_argument('-num_processes', type=int, default=8, required=False, help="number of processes used. Default: 8")
+ parser.add_argument('-modality_idx', type=int, default=0, required=False,
+ help="modality index used (0 = _0000.nii.gz). Default: 0")
+ parser.add_argument('--use_raw', action='store_true', required=False, help="if set then we use raw data. else "
+ "we use preprocessed")
+ args = parser.parse_args()
+
+ generate_overlays_for_task(args.t, args.o, args.num_processes, args.modality_idx, use_preprocessed=not args.use_raw)
\ No newline at end of file
diff --git a/nnunet/utilities/random_stuff.py b/nnunet/utilities/random_stuff.py
new file mode 100644
index 0000000000000000000000000000000000000000..b94db933774a0ff1dfdec89af3c3b8f40804728a
--- /dev/null
+++ b/nnunet/utilities/random_stuff.py
@@ -0,0 +1,21 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+class no_op(object):
+ def __enter__(self):
+ pass
+
+ def __exit__(self, *args):
+ pass
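+
+
+# Usage sketch (an illustration): no_op can stand in for a real context manager,
+# e.g. context = autocast if fp16 else no_op, so that `with context():` works
+# in both cases without branching at every call site.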
diff --git a/nnunet/utilities/recursive_delete_npz.py b/nnunet/utilities/recursive_delete_npz.py
new file mode 100644
index 0000000000000000000000000000000000000000..60428778a73f336ade51805176b89e5780fc2384
--- /dev/null
+++ b/nnunet/utilities/recursive_delete_npz.py
@@ -0,0 +1,37 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from batchgenerators.utilities.file_and_folder_operations import *
+import argparse
+import os
+
+
+def recursive_delete_npz(current_directory: str):
+ npz_files = subfiles(current_directory, join=True, suffix=".npz")
+ npz_files = [i for i in npz_files if not i.endswith("segFromPrevStage.npz")] # to be extra safe
+ _ = [os.remove(i) for i in npz_files]
+ for d in subdirs(current_directory, join=False):
+ if d != "pred_next_stage":
+ recursive_delete_npz(join(current_directory, d))
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(usage="USE THIS RESPONSIBLY! DANGEROUS! I (Fabian) use this to remove npz files "
+ "after I ran figure_out_what_to_submit")
+ parser.add_argument("-f", help="folder", required=True)
+
+ args = parser.parse_args()
+
+ recursive_delete_npz(args.f)
diff --git a/nnunet/utilities/recursive_rename_taskXX_to_taskXXX.py b/nnunet/utilities/recursive_rename_taskXX_to_taskXXX.py
new file mode 100644
index 0000000000000000000000000000000000000000..569bd545345fc6419311267f9596cc4d75cf70fc
--- /dev/null
+++ b/nnunet/utilities/recursive_rename_taskXX_to_taskXXX.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from batchgenerators.utilities.file_and_folder_operations import *
+import os
+
+
+def recursive_rename(folder):
+ s = subdirs(folder, join=False)
+ for ss in s:
+ if ss.startswith("Task") and ss.find("_") == 6:
+ task_id = int(ss[4:6])
+ name = ss[7:]
+ os.rename(join(folder, ss), join(folder, "Task%03.0d_" % task_id + name))
+ s = subdirs(folder, join=True)
+ for ss in s:
+ recursive_rename(ss)
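+
+# Example: 'Task03_Liver' (underscore at index 6) is renamed to 'Task003_Liver';
+# folders that already use the three-digit TaskXXX_ scheme are left unchanged.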
+
+if __name__ == "__main__":
+ recursive_rename("/media/fabian/Results/nnUNet")
+ recursive_rename("/media/fabian/nnunet")
+ recursive_rename("/media/fabian/My Book/MedicalDecathlon")
+ recursive_rename("/home/fabian/drives/datasets/nnUNet_raw")
+ recursive_rename("/home/fabian/drives/datasets/nnUNet_preprocessed")
+ recursive_rename("/home/fabian/drives/datasets/nnUNet_testSets")
+ recursive_rename("/home/fabian/drives/datasets/results/nnUNet")
+ recursive_rename("/home/fabian/drives/e230-dgx2-1-data_fabian/Decathlon_raw")
+ recursive_rename("/home/fabian/drives/e230-dgx2-1-data_fabian/nnUNet_preprocessed")
+
diff --git a/nnunet/utilities/set_n_proc_DA.py b/nnunet/utilities/set_n_proc_DA.py
new file mode 100644
index 0000000000000000000000000000000000000000..09aff913788960f736cb63a85e8df025fa532834
--- /dev/null
+++ b/nnunet/utilities/set_n_proc_DA.py
@@ -0,0 +1,41 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import subprocess
+import os
+
+
+def get_allowed_n_proc_DA():
+ hostname = subprocess.getoutput('hostname') # getoutput expects a shell command string, not a list
+
+ if 'nnUNet_n_proc_DA' in os.environ.keys():
+ return int(os.environ['nnUNet_n_proc_DA'])
+
+ if hostname in ['hdf19-gpu16', 'hdf19-gpu17', 'e230-AMDworkstation']:
+ return 16
+
+ if hostname in ['Fabian',]:
+ return 12
+
+ if hostname.startswith('hdf19-gpu') or hostname.startswith('e071-gpu'):
+ return 12
+ elif hostname.startswith('e230-dgx1'):
+ return 10
+ elif hostname.startswith('hdf18-gpu') or hostname.startswith('e132-comp'):
+ return 16
+ elif hostname.startswith('e230-dgx2'):
+ return 6
+ elif hostname.startswith('e230-dgxa100-'):
+ return 32
+ else:
+ return None
\ No newline at end of file
diff --git a/nnunet/utilities/sitk_stuff.py b/nnunet/utilities/sitk_stuff.py
new file mode 100644
index 0000000000000000000000000000000000000000..36d38db8f604e5656a3ea7417b12ac857aefdd2b
--- /dev/null
+++ b/nnunet/utilities/sitk_stuff.py
@@ -0,0 +1,24 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+import SimpleITK as sitk
+
+
+def copy_geometry(image: sitk.Image, ref: sitk.Image):
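+    # Copy origin, direction and spacing from ref onto image (in place) and return it.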
+ image.SetOrigin(ref.GetOrigin())
+ image.SetDirection(ref.GetDirection())
+ image.SetSpacing(ref.GetSpacing())
+ return image
diff --git a/nnunet/utilities/task_name_id_conversion.py b/nnunet/utilities/task_name_id_conversion.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a0b22a5f61c6813d604de1165d5f52bfc4ed07f
--- /dev/null
+++ b/nnunet/utilities/task_name_id_conversion.py
@@ -0,0 +1,70 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from nnunet.paths import nnUNet_raw_data, preprocessing_output_dir, nnUNet_cropped_data, network_training_output_dir
+from batchgenerators.utilities.file_and_folder_operations import *
+import numpy as np
+import os
+
+
+def convert_id_to_task_name(task_id: int):
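+    # Search the preprocessed, raw, cropped and trained-model directories for a
+    # folder whose name starts with "TaskXXX"; exactly one unique match is expected.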
+ startswith = "Task%03.0d" % task_id
+ if preprocessing_output_dir is not None:
+ candidates_preprocessed = subdirs(preprocessing_output_dir, prefix=startswith, join=False)
+ else:
+ candidates_preprocessed = []
+
+ if nnUNet_raw_data is not None:
+ candidates_raw = subdirs(nnUNet_raw_data, prefix=startswith, join=False)
+ else:
+ candidates_raw = []
+
+ if nnUNet_cropped_data is not None:
+ candidates_cropped = subdirs(nnUNet_cropped_data, prefix=startswith, join=False)
+ else:
+ candidates_cropped = []
+
+ candidates_trained_models = []
+ if network_training_output_dir is not None:
+ for m in ['2d', '3d_lowres', '3d_fullres', '3d_cascade_fullres']:
+ if isdir(join(network_training_output_dir, m)):
+ candidates_trained_models += subdirs(join(network_training_output_dir, m), prefix=startswith, join=False)
+
+ all_candidates = candidates_cropped + candidates_preprocessed + candidates_raw + candidates_trained_models
+ unique_candidates = np.unique(all_candidates)
+ if len(unique_candidates) > 1:
+ raise RuntimeError("More than one task name found for task id %d. Please correct that. (I looked in the "
+ "following folders:\n%s\n%s\n%s" % (task_id, nnUNet_raw_data, preprocessing_output_dir,
+ nnUNet_cropped_data))
+ if len(unique_candidates) == 0:
+ raise RuntimeError("Could not find a task with the ID %d. Make sure the requested task ID exists and that "
+ "nnU-Net knows where raw and preprocessed data are located (see Documentation - "
+ "Installation). Here are your currently defined folders:\nnnUNet_preprocessed=%s\nRESULTS_"
+ "FOLDER=%s\nnnUNet_raw_data_base=%s\nIf something is not right, adapt your environemnt "
+ "variables." %
+ (task_id,
+ os.environ.get('nnUNet_preprocessed') if os.environ.get('nnUNet_preprocessed') is not None else 'None',
+ os.environ.get('RESULTS_FOLDER') if os.environ.get('RESULTS_FOLDER') is not None else 'None',
+ os.environ.get('nnUNet_raw_data_base') if os.environ.get('nnUNet_raw_data_base') is not None else 'None',
+ ))
+ return unique_candidates[0]
+
+
+def convert_task_name_to_id(task_name: str):
+ assert task_name.startswith("Task")
+ task_id = int(task_name[4:7])
+ return task_id
diff --git a/nnunet/utilities/tensor_utilities.py b/nnunet/utilities/tensor_utilities.py
new file mode 100644
index 0000000000000000000000000000000000000000..daded59b43f87762a90852222325a5eed5be9f9a
--- /dev/null
+++ b/nnunet/utilities/tensor_utilities.py
@@ -0,0 +1,56 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+
+
+def sum_tensor(inp, axes, keepdim=False):
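+    # Sum inp over every axis in axes. Without keepdim, reduce from the highest
+    # axis down so earlier axis indices stay valid as dimensions disappear.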
+ axes = np.unique(axes).astype(int)
+ if keepdim:
+ for ax in axes:
+ inp = inp.sum(int(ax), keepdim=True)
+ else:
+ for ax in sorted(axes, reverse=True):
+ inp = inp.sum(int(ax))
+ return inp
+
+
+def mean_tensor(inp, axes, keepdim=False):
+ axes = np.unique(axes).astype(int)
+ if keepdim:
+ for ax in axes:
+ inp = inp.mean(int(ax), keepdim=True)
+ else:
+ for ax in sorted(axes, reverse=True):
+ inp = inp.mean(int(ax))
+ return inp
+
+
+def flip(x, dim):
+ """
+ flips the tensor at dimension dim (mirroring!)
+ :param x:
+ :param dim:
+ :return:
+ """
+ indices = [slice(None)] * x.dim()
+ indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
+ dtype=torch.long, device=x.device)
+ return x[tuple(indices)]
+
+
diff --git a/nnunet/utilities/to_torch.py b/nnunet/utilities/to_torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab68035eb19774540b7ca46a177a4d26d6fb3a4f
--- /dev/null
+++ b/nnunet/utilities/to_torch.py
@@ -0,0 +1,33 @@
+# Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+
+
+def maybe_to_torch(d):
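+    # Recursively convert numpy arrays (and lists of them) to float32 tensors;
+    # anything that is already a torch.Tensor passes through unchanged.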
+ if isinstance(d, list):
+ d = [maybe_to_torch(i) if not isinstance(i, torch.Tensor) else i for i in d]
+ elif not isinstance(d, torch.Tensor):
+ d = torch.from_numpy(d).float()
+ return d
+
+
+def to_cuda(data, non_blocking=True, gpu_id=0):
+ if isinstance(data, list):
+ data = [i.cuda(gpu_id, non_blocking=non_blocking) for i in data]
+ else:
+ data = data.cuda(gpu_id, non_blocking=non_blocking)
+ return data
diff --git a/preprocessing/cut_flip.py b/preprocessing/cut_flip.py
new file mode 100644
index 0000000000000000000000000000000000000000..a258e561279d44a2ed8356aedd525b7acc4c8cb6
--- /dev/null
+++ b/preprocessing/cut_flip.py
@@ -0,0 +1,108 @@
+import numpy as np
+import glob
+import ants
+import nibabel as nib
+import os
+import argparse
+
+def parse_command_line():
+ parser = argparse.ArgumentParser(
+ description='pipeline for data preprocessing')
+ parser.add_argument('-bp', metavar='base path', type=str,
+ help="absolute path of the base directory")
+ parser.add_argument('-ip', metavar='image path', type=str,
+ help="relative path of the image directory")
+ parser.add_argument('-sp', metavar='segmentation path', type=str,
+ help="relative path of the image directory")
+ parser.add_argument('-op', metavar='preprocessing result output path', type=str, default='output',
+ help='relative path of the preprocessing result directory')
+ argv = parser.parse_args()
+ return argv
+
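+# Split the volume at the x midline of the segmentation's bounding box and
+# mirror the right half so both halves share the left half's orientation.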
+def flip(nib_img, nib_seg, ants_img, ants_seg, seg_format):
+    img = nib_img.get_fdata()
+    if seg_format in ('nii.gz', 'nii'):
+ seg = nib_seg.get_fdata()
+ else:
+ seg = nib_seg[0]
+ gem = ants.label_geometry_measures(ants_seg, ants_img)
+ low_x = min(list(gem.loc[:, 'BoundingBoxLower_x']))
+ upp_x = max(list(gem.loc[:, 'BoundingBoxUpper_x']))
+ low_y = min(list(gem.loc[:, 'BoundingBoxLower_y']))
+ upp_y = max(list(gem.loc[:, 'BoundingBoxUpper_y']))
+ low_z = min(list(gem.loc[:, 'BoundingBoxLower_z']))
+ upp_z = max(list(gem.loc[:, 'BoundingBoxUpper_z']))
+    # Compute the midpoint along x; only the x extent is used for the split
+    mid_x = int((low_x + upp_x) / 2)
+
+ left_seg = seg[:mid_x, :, :]
+ left_img = img[:mid_x, :, :]
+ right_seg = seg[mid_x:, :, :]
+ right_img = img[mid_x:, :, :]
+ flipped_right_seg = np.flip(right_seg, axis=0)
+ flipped_right_img = np.flip(right_img, axis=0)
+ print("finish flip")
+ return left_img, left_seg, flipped_right_img, flipped_right_seg
+
+def load_data(img_path, seg_path):
+ nib_seg = nib.load(seg_path)
+ nib_img = nib.load(img_path)
+ ants_seg = ants.image_read(seg_path)
+ ants_img = ants.image_read(img_path)
+ return nib_img, nib_seg, ants_img, ants_seg
+
+
+def crop_flip_save_file(left_img, left_seg, flipped_right_img, flipped_right_seg, nib_img, nib_seg, output_img, output_seg, scan_id):
+ left_img_nii = nib.Nifti1Image(
+ left_img, affine=nib_img.affine, header=nib_img.header)
+ left_seg_nii = nib.Nifti1Image(
+ left_seg, affine=nib_seg.affine, header=nib_seg.header)
+ right_img_nii = nib.Nifti1Image(
+ flipped_right_img, affine=nib_img.affine, header=nib_img.header)
+ right_seg_nii = nib.Nifti1Image(
+ flipped_right_seg, affine=nib_seg.affine, header=nib_seg.header)
+ left_img_nii.to_filename(os.path.join(
+ output_img, scan_id + '1.nii.gz'))
+ left_seg_nii.to_filename(os.path.join(
+ output_seg, scan_id + '1.nii.gz'))
+ right_img_nii.to_filename(os.path.join(
+ output_img, scan_id + '0.nii.gz'))
+ right_seg_nii.to_filename(os.path.join(
+ output_seg, scan_id + '0.nii.gz'))
+
+
+def main():
+ args = parse_command_line()
+ base_path = args.bp
+ image_path = os.path.join(base_path, args.ip)
+ seg_path = os.path.join(base_path, args.sp)
+ output_path = os.path.join(base_path, args.op)
+ output_img = os.path.join(output_path, 'images')
+ output_seg = os.path.join(output_path, 'labels')
+    try:
+        os.mkdir(output_path)
+    except FileExistsError:
+        print(f'{output_path} already exists')
+
+    try:
+        os.mkdir(output_img)
+    except FileExistsError:
+        print(f'{output_img} already exists')
+
+    try:
+        os.mkdir(output_seg)
+    except FileExistsError:
+        print(f'{output_seg} already exists')
+
+ for i in sorted(glob.glob(image_path + '/*nii.gz')):
+ id = os.path.basename(i).split('.')[0]
+ label_path = os.path.join(seg_path, id + '.nii.gz')
+ nib_img, nib_seg, ants_img, ants_seg = load_data(i, label_path)
+ left_img, left_seg, flipped_right_img, flipped_right_seg = flip(nib_img, nib_seg, ants_img, ants_seg, 'nii.gz')
+ print('Scan ID: ' + id + f', img & seg before cropping: {nib_img.get_fdata().shape}, after flipping: {left_img.shape} and {flipped_right_img.shape}')
+ crop_flip_save_file(left_img, left_seg, flipped_right_img, flipped_right_seg, nib_img, nib_seg, output_img, output_seg, id)
+
+if __name__ == '__main__':
+ main()
\ No newline at end of file
diff --git a/preprocessing/deface.py b/preprocessing/deface.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ddca0b1435f06a4b808f05b87467686c8a40e19
--- /dev/null
+++ b/preprocessing/deface.py
@@ -0,0 +1,69 @@
+import numpy as np
+import nibabel as nib
+import os
+from glob import glob
+import argparse
+
+
+def parse_command_line():
+ print('---'*10)
+ print('Parsing Command Line Arguments')
+ parser = argparse.ArgumentParser(
+ description='Defacing protocol')
+    parser.add_argument('-sc', metavar='Scans', type=str,
+                        help="Relative path of the directory containing the scans to deface")
+    parser.add_argument('-mk', metavar='Masks', type=str,
+                        help="Relative path of the directory containing the corresponding masks")
+ parser.add_argument('-bp', metavar='base path', type=str,
+ help="Absolute path of the base directory")
+ argv = parser.parse_args()
+ return argv
+
+
+def deface(input_file, mask_file, output_file=None, suffix=" (masked)", write=True):
+    # Load the original CT volume
+    input_img = nib.load(input_file)
+
+    # Load the segmentation mask
+    segmentation = nib.load(mask_file)
+
+    input_array = input_img.get_fdata()
+    segmentation_array = segmentation.get_fdata()
+    mask = 1 - segmentation_array  # 0's inside the mask, 1's outside
+
+    # Create the masked (defaced) CT volume
+    output_array = input_array * mask
+    output = nib.Nifti1Image(output_array, input_img.affine, input_img.header)
+
+ # Save the masked CT volume
+ if write:
+ if output_file is None: # Save in same folder but with suffix
+ output_file = input_file.split(".")[0] + suffix + ".nii.gz"
+ output.to_filename(output_file)
+ else:
+ # Otherwise, save to specified path
+ output.to_filename(output_file)
+ return output
+
+
+def main():
+ args = parse_command_line()
+ base = args.bp
+ images = args.sc
+ masks = args.mk
+
+ CT_images = sorted(glob(os.path.join(base, images, '*.nii.gz')))
+ mask_images = sorted(glob(os.path.join(base, masks, '*.nii.gz')))
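+    # Pairing relies on the two sorted listings lining up one-to-one, so scans
+    # and masks must share their file names.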
+
+    num = len(CT_images)
+    print(f'Defacing {num} scans')
+
+ for i in range(num):
+ deface(CT_images[i], mask_images[i], output_file=None,
+ suffix=' (masked)', write=True)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/preprocessing/registration.py b/preprocessing/registration.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a483c3b754e5da3065235751c1df556fbd071c5
--- /dev/null
+++ b/preprocessing/registration.py
@@ -0,0 +1,387 @@
+import os
+import ants
+import nrrd
+import numpy as np
+import glob
+import slicerio
+import shutil
+import argparse
+
+
+def parse_command_line():
+ print('-----'*10)
+ print('Parsing Command Line Arguments')
+ parser = argparse.ArgumentParser(
+ description='pipeline for dataset nnUNet preprocessing')
+ parser.add_argument('-bp', metavar='base path', type=str,
+ help="Absolute path of the base directory")
+ parser.add_argument('-ip', metavar='image path', type=str,
+ help="Relative path of the image directory")
+    parser.add_argument('-sp', metavar='segmentation path', type=str,
+                        help="Relative path of the segmentation directory")
+    parser.add_argument('-sl', metavar='segmentation information list', type=str, nargs='+',
+                        help='pairs of label value and label name (value first, then name)')
+ argv = parser.parse_args()
+ return argv
+
+
+def split_and_registration(template, target, base, images_path, seg_path, fomat, checked=False, has_label=False):
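+    # Register target to template with a fast affine transform; when a label file
+    # exists, the same forward transform propagates the segmentation as well.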
+ print('-----'*10)
+ print('Creating file paths')
+ # Define the path for template, target, and segmentations (from template)
+ fixed_path = os.path.join(base, images_path, template + '.' + fomat)
+ moving_path = os.path.join(base, images_path, target + '.' + fomat)
+ images_output = os.path.join(base, 'imagesRS/', target + '.nii.gz')
+ print('-----'*10)
+ print(f'Reading in the template {template} and target {target} image')
+ # Read the template and target image
+ template_image = ants.image_read(fixed_path)
+ target_image = ants.image_read(moving_path)
+ print('-----'*10)
+ print('Performing the template and target image registration')
+ transform_forward = ants.registration(fixed=template_image, moving=target_image,
+ type_of_transform="AffineFast", verbose=False)
+ if has_label:
+ segmentation_path = os.path.join(
+ base, seg_path, target + '.nii.gz')
+ segmentation_output = os.path.join(
+ base, 'labelsRS/', target + '.nii.gz')
+ print('-----'*10)
+ print('Reading in the segmentation')
+ # Split segmentations into individual components
+ segment_target = ants.image_read(segmentation_path)
+ print('-----'*10)
+ print('Applying the transformation for label propagation and image registration')
+ predicted_targets_image = ants.apply_transforms(
+ fixed=template_image,
+ moving=segment_target,
+ transformlist=transform_forward["fwdtransforms"],
+ interpolator="genericLabel",
+ verbose=False)
+ predicted_targets_image.to_file(segmentation_output)
+
+ reg_img = ants.apply_transforms(
+ fixed=template_image,
+ moving=target_image,
+ transformlist=transform_forward["fwdtransforms"],
+ interpolator="linear",
+ verbose=False)
+ print('-----'*10)
+ print("writing out transformed template segmentation")
+ reg_img.to_file(images_output)
+ print('Label Propagation & Image Registration complete')
+
+
+def convert_to_one_hot(data, header, segment_indices=None):
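+    # Builds a (num_segments, x, y, z) one-hot array; data that is already
+    # one-hot (older Slicer NRRD) is returned unchanged.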
+ print('-----'*10)
+ print("converting to one hot")
+
+ layer_values = get_layer_values(header)
+ label_values = get_label_values(header)
+
+ # Newer Slicer NRRD (compressed layers)
+ if layer_values and label_values:
+
+ assert len(layer_values) == len(label_values)
+ if len(data.shape) == 3:
+ x_dim, y_dim, z_dim = data.shape
+ elif len(data.shape) == 4:
+ x_dim, y_dim, z_dim = data.shape[1:]
+
+ num_segments = len(layer_values)
+ one_hot = np.zeros((num_segments, x_dim, y_dim, z_dim))
+
+ if segment_indices is None:
+ segment_indices = list(range(num_segments))
+
+ elif isinstance(segment_indices, int):
+ segment_indices = [segment_indices]
+
+ elif not isinstance(segment_indices, list):
+ print("incorrectly specified segment indices")
+ return
+
+ # Check if NRRD is composed of one layer 0
+ if np.max(layer_values) == 0:
+ for i, seg_idx in enumerate(segment_indices):
+ layer = layer_values[seg_idx]
+ label = label_values[seg_idx]
+ one_hot[i] = 1*(data == label).astype(np.uint8)
+
+ else:
+ for i, seg_idx in enumerate(segment_indices):
+ layer = layer_values[seg_idx]
+ label = label_values[seg_idx]
+ one_hot[i] = 1*(data[layer] == label).astype(np.uint8)
+
+ # Binary labelmap
+ elif len(data.shape) == 3:
+ x_dim, y_dim, z_dim = data.shape
+ num_segments = np.max(data)
+ one_hot = np.zeros((num_segments, x_dim, y_dim, z_dim))
+
+ if segment_indices is None:
+ segment_indices = list(range(1, num_segments + 1))
+
+ elif isinstance(segment_indices, int):
+ segment_indices = [segment_indices]
+
+ elif not isinstance(segment_indices, list):
+ print("incorrectly specified segment indices")
+ return
+
+ for i, seg_idx in enumerate(segment_indices):
+ one_hot[i] = 1*(data == seg_idx).astype(np.uint8)
+
+ # Older Slicer NRRD (already one-hot)
+ else:
+ return data
+
+ return one_hot
+
+
+def get_layer_values(header, indices=None):
+ layer_values = []
+ num_segments = len([key for key in header.keys() if "Layer" in key])
+ for i in range(num_segments):
+ layer_values.append(int(header['Segment{}_Layer'.format(i)]))
+ return layer_values
+
+
+def get_label_values(header, indices=None):
+ label_values = []
+ num_segments = len([key for key in header.keys() if "LabelValue" in key])
+ for i in range(num_segments):
+ label_values.append(int(header['Segment{}_LabelValue'.format(i)]))
+ return label_values
+
+
+def get_num_segments(header, indices=None):
+ num_segments = len([key for key in header.keys() if "LabelValue" in key])
+ return num_segments
+
+
+def checkCorrespondence(segmentation, base, paired_list, filename):
+ print(filename)
+ assert type(paired_list) == list
+    data, header = nrrd.read(os.path.join(base, segmentation, filename))
+    seg_info = slicerio.read_segmentation_info(
+        os.path.join(base, segmentation, filename))
+    output_voxels, output_header = slicerio.extract_segments(
+        data, header, seg_info, paired_list)
+    output = os.path.join(base, 'MatchedSegs', filename)
+ nrrd.write(output, output_voxels, output_header)
+ print('---'*10)
+ print('Check the label names and values')
+ print(slicerio.read_segmentation_info(output))
+ return output
+
+
+def checkSegFormat(base, segmentation, paired_list, check=False):
+ path = os.path.join(base, segmentation)
+ save_dir = os.path.join(base, 're-format_labels')
+ try:
+ os.mkdir(save_dir)
+    except FileExistsError:
+ print(f'{save_dir} already exists')
+
+ for file in os.listdir(path):
+ name = file.split('.')[0]
+ if file.endswith('seg.nrrd') or file.endswith('nrrd'):
+ if check:
+ output_path = checkCorrespondence(
+ segmentation, base, paired_list, file)
+ ants_img = ants.image_read(output_path)
+ header = nrrd.read_header(output_path)
+ else:
+ ants_img = ants.image_read(os.path.join(path, file))
+ header = nrrd.read_header(os.path.join(path, file))
+ segmentations = True
+ filename = os.path.join(save_dir, name + '.nii.gz')
+ nrrd2nifti(ants_img, header, filename, segmentations)
+ elif file.endswith('nii'):
+ image = ants.image_read(os.path.join(path, file))
+ image.to_file(os.path.join(save_dir, name + '.nii.gz'))
+ elif file.endswith('nii.gz'):
+ shutil.copy(os.path.join(path, file), save_dir)
+
+ return save_dir
+
+
+def nrrd2nifti(img, header, filename, segmentations=True):
+ img_as_np = img.view(single_components=segmentations)
+ if segmentations:
+ data = convert_to_one_hot(img_as_np, header)
+ foreground = np.max(data, axis=0)
+ labelmap = np.multiply(np.argmax(data, axis=0) + 1,
+ foreground).astype('uint8')
+ segmentation_img = ants.from_numpy(
+ labelmap, origin=img.origin, spacing=img.spacing, direction=img.direction)
+ print('-- Saving NII Segmentations')
+ segmentation_img.to_file(filename)
+ else:
+ print('-- Saving NII Volume')
+ img.to_file(filename)
+
+
+def find_template(base, image_path, fomat):
+ scans = sorted(glob.glob(os.path.join(base, image_path) + '/*' + fomat))
+ template = os.path.basename(scans[0]).split('.')[0]
+ return template
+
+
+def find_template_V2(base, image_path, fomat):
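+    # Pick the scan with the most slices along the third axis as the template.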
+ maxD = -np.inf
+ for i in glob.glob(os.path.join(base, image_path) + '/*' + fomat):
+ id = os.path.basename(i).split('.')[0]
+ img = ants.image_read(i)
+ thirdD = img.shape[2]
+ if thirdD > maxD:
+ template = id
+ maxD = thirdD
+ print(maxD, template)
+ return template
+
+
+def path_to_id(path, fomat):
+ ids = []
+ for i in glob.glob(path + '/*' + fomat):
+ id = os.path.basename(i).split('.')[0]
+ ids.append(id)
+ return ids
+
+
+def checkFormat(base, images_path):
+    path = os.path.join(base, images_path)
+    ret = None
+    # check the more specific extensions first ('.seg.nrrd' also ends with '.nrrd')
+    for file in os.listdir(path):
+        if file.endswith('.seg.nrrd'):
+            ret = 'seg.nrrd'
+            break
+        elif file.endswith('.nrrd'):
+            ret = 'nrrd'
+            break
+        elif file.endswith('.nii.gz'):
+            ret = 'nii.gz'
+            break
+        elif file.endswith('.nii'):
+            ret = 'nii'
+            break
+    return ret
+
+
+def main():
+ args = parse_command_line()
+ base = args.bp
+ images_path = args.ip
+ segmentation = args.sp
+ label_list = args.sl
+ images_output = os.path.join(base, 'imagesRS')
+ labels_output = os.path.join(base, 'labelsRS')
+ fomat = checkFormat(base, images_path)
+ fomat_seg = checkFormat(base, segmentation)
+ template = find_template(base, images_path, fomat)
+ label_lists = path_to_id(os.path.join(base, segmentation), fomat_seg)
+ if label_list is not None:
+ matched_output = os.path.join(base, 'MatchedSegs')
+ try:
+ os.mkdir(matched_output)
+        except FileExistsError:
+ print(f"{matched_output} already exists")
+
+ try:
+ os.mkdir(images_output)
+    except FileExistsError:
+ print(f"{images_output} already exists")
+
+ try:
+ os.mkdir(labels_output)
+    except FileExistsError:
+ print(f"{labels_output} already exists")
+
+ paired_list = []
+ if label_list is not None:
+        for i in range(0, len(label_list), 2):
+            if not label_list[i].isdigit():
+                print(
+                    "Wrong argument order: each -sl pair must be a label value followed by its name")
+                return
+            else:
+                value = label_list[i]
+                if not label_list[i+1].isdigit():
+                    key = label_list[i+1]
+                    paired_list.append((key, value))
+                else:
+                    print(
+                        "Wrong -sl arguments: each pair must be a label value followed by its name")
+                    return
+
+ # print(new_segmentation)
+ seg_output_path = checkSegFormat(
+ base, segmentation, paired_list, check=True)
+        for j in sorted(glob.glob(os.path.join(base, images_path) + '/*' + fomat)):
+            id = os.path.basename(j).split('.')[0]
+            if id == template:
+                pass
+            else:
+                target = id
+                if id in label_lists:
+                    split_and_registration(
+                        template, target, base, images_path, seg_output_path, fomat, checked=True, has_label=True)
+                else:
+                    split_and_registration(
+                        template, target, base, images_path, seg_output_path, fomat, checked=True, has_label=False)
+
+        image = ants.image_read(os.path.join(
+            base, images_path, template + '.' + fomat))
+        image.to_file(os.path.join(base, images_output, template + '.nii.gz'))
+        fomat = 'nii.gz'
+        images_path = os.path.join(base, 'imagesRS/')
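+        # Finally bring the template itself into the common space by registering
+        # it against the last registered target inside imagesRS.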
+        if template in label_lists:
+            split_and_registration(
+                target, template, base, images_path, seg_output_path, fomat, checked=True, has_label=True)
+        else:
+            split_and_registration(
+                target, template, base, images_path, seg_output_path, fomat, checked=True, has_label=False)
+
+ else:
+ seg_output_path = checkSegFormat(
+ base, segmentation, paired_list, check=False)
+
+        for i in sorted(glob.glob(os.path.join(base, images_path) + '/*' + fomat)):
+            id = os.path.basename(i).split('.')[0]
+            if id == template:
+                pass
+            else:
+                target = id
+                if id in label_lists:
+                    split_and_registration(
+                        template, target, base, images_path, seg_output_path, fomat, checked=False, has_label=True)
+                else:
+                    split_and_registration(
+                        template, target, base, images_path, seg_output_path, fomat, checked=False, has_label=False)
+
+        image = ants.image_read(os.path.join(
+            base, images_path, template + '.' + fomat))
+        image.to_file(os.path.join(base, images_output, template + '.nii.gz'))
+
+        images_path = os.path.join(base, 'imagesRS/')
+        fomat = 'nii.gz'
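+        # As in the labeled branch, register the template itself against the last target.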
+        if template in label_lists:
+            split_and_registration(
+                target, template, base, images_path, seg_output_path, fomat, checked=True, has_label=True)
+        else:
+            split_and_registration(
+                target, template, base, images_path, seg_output_path, fomat, checked=True, has_label=False)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/preprocessing/split_data.py b/preprocessing/split_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f4714e3dff2f864751c6071ff5d14e18c4bd9d6
--- /dev/null
+++ b/preprocessing/split_data.py
@@ -0,0 +1,197 @@
+'''
+python3 split_data.py -bp BasePath -ip RegisteredImageFolder -sp RegisteredLabelFolder ...
+
+Given the registered images and labels, builds one nnU-Net task folder per fold of a
+k-fold split (default k=5): for each fold, that fold's scans become the test set and
+the scans of the remaining folds become the training set.
+'''
+import os
+import glob
+import random
+import shutil
+
+from typing import Tuple
+import numpy as np
+from collections import OrderedDict
+import json
+import argparse
+
+
+"""
+creates a folder at a specified folder path if it does not exists
+folder_path : relative path of the folder (from cur_dir) which needs to be created
+over_write :(default: False) if True overwrite the existing folder
+ """
+def parse_command_line():
+ print('---'*10)
+ print('Parsing Command Line Arguments')
+ parser = argparse.ArgumentParser(
+ description='pipeline for dataset split')
+ parser.add_argument('-bp', metavar='base path', type=str,
+ help="Absolute path of the base directory")
+ parser.add_argument('-ip', metavar='image path', type=str,
+ help="Relative path of the image directory")
+ parser.add_argument('-sp', metavar='segmentation path', type=str,
+ help="Relative path of the image directory")
+ parser.add_argument('-sl', metavar='segmentation information list', type=str, nargs='+',
+                        help='pairs of label value and label name (value first, then name)')
+ parser.add_argument('-ti', metavar='task id', type=int,
+ help='task id number')
+ parser.add_argument('-tn', metavar='task name', type=str,
+ help='task name')
+ parser.add_argument('-kf', metavar='k-fold validation', type=int, default=5,
+ help='k-fold validation')
+ argv = parser.parse_args()
+ return argv
+
+
+def make_if_dont_exist(folder_path, overwrite=False):
+    """
+    Creates a folder at folder_path if it does not exist.
+    folder_path: path of the folder to create
+    overwrite: (default False) if True, delete and recreate an existing folder
+    """
+ if os.path.exists(folder_path):
+ if not overwrite:
+ print(f'{folder_path} exists.')
+ else:
+ print(f"{folder_path} overwritten")
+ shutil.rmtree(folder_path)
+ os.makedirs(folder_path)
+ else:
+ os.makedirs(folder_path)
+ print(f"{folder_path} created!")
+
+
+def rename(location, oldname, newname):
+
+ os.rename(os.path.join(location, oldname), os.path.join(location, newname))
+
+
+def main():
+ args = parse_command_line()
+ base = args.bp
+ reg_data_path = args.ip
+ lab_data_path = args.sp
+ task_id = args.ti
+ Name = args.tn
+ k_fold = args.kf
+ seg_list = args.sl
+ base_dir = "/home/ameen"
+ #os.chdir(base_dir)
+ nnunet_dir = "nnUNet/nnunet/nnUNet_raw_data_base/nnUNet_raw_data"
+ main_dir = os.path.join(base_dir, 'nnUNet/nnunet')
+ make_if_dont_exist(os.path.join(main_dir, 'nnUNet_preprocessed'))
+ make_if_dont_exist(os.path.join(main_dir, 'nnUNet_trained_models'))
+
+ os.environ['nnUNet_raw_data_base'] = os.path.join(
+ main_dir, 'nnUNet_raw_data_base')
+ os.environ['nnUNet_preprocessed'] = os.path.join(
+ main_dir, 'nnUNet_preprocessed')
+ os.environ['RESULTS_FOLDER'] = os.path.join(
+ main_dir, 'nnUNet_trained_models')
+
+ random.seed(19)
+
+ image_list = glob.glob(os.path.join(base, reg_data_path) + "/*.nii.gz")
+ label_list = glob.glob(os.path.join(base, lab_data_path) + "/*.nii.gz")
+ num_images = len(image_list)
+    # compute the number of scans per fold
+    num_each_fold, num_remain = divmod(num_images, k_fold)
+    fold_num = np.repeat(num_each_fold, k_fold)
+ count = 0
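+    # hand out the remainder one scan at a time across the first folds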
+ while num_remain > 0:
+ fold_num[count] += 1
+        count = (count+1) % k_fold
+ num_remain -= 1
+
+ random.shuffle(image_list)
+ piece_data = {}
+ start_point = 0
+ # select scans for each fold
+ for m in range(k_fold):
+ piece_data[f'fold_{m}'] = image_list[start_point:start_point+fold_num[m]]
+ start_point += fold_num[m]
+
+ for j in range(k_fold):
+ task_name = f"Task0{task_id}_{Name}_fold{j}" # MODIFY
+ task_id += 1
+ task_folder_name = os.path.join(base_dir, nnunet_dir, task_name)
+ train_image_dir = os.path.join(task_folder_name, 'imagesTr')
+ train_label_dir = os.path.join(task_folder_name, 'labelsTr')
+ test_dir = os.path.join(task_folder_name, 'imagesTs')
+
+ make_if_dont_exist(task_folder_name)
+ make_if_dont_exist(train_image_dir)
+ make_if_dont_exist(train_label_dir)
+ make_if_dont_exist(test_dir)
+        # Dataset split: fold j for testing, the remaining folds for training
+        num_test = fold_num[j]
+        num_train = np.sum(fold_num) - num_test
+ print("Number of training subjects: ", num_train,
+ "\nNumber of testing subjects:", num_test, "\nTotal:", num_images)
+        train_images = []
+        # concatenate the remaining k-1 folds for training
+        for p in range(k_fold):
+            if p != j:
+                train_images.extend(piece_data[f'fold_{p}'])
+ # select one fold for testing
+ test_images = piece_data[f'fold_{j}']
+ # prepare for nnUNet training scans and labels
+ for i in range(len(train_images)):
+ filename1 = os.path.basename(train_images[i]).split(".")[0]
+ number = ''.join(filter(lambda x: x.isdigit(), filename1))
+ # put this image to the training folder
+ shutil.copy(train_images[i], train_image_dir)
+ filename = os.path.basename(train_images[i])
+ rename(train_image_dir, filename, Name + "_" + number + "_0000.nii.gz")
+
+ for label_dir in label_list:
+ if label_dir.endswith(os.path.basename(train_images[i])):
+ shutil.copy(label_dir, train_label_dir)
+ rename(train_label_dir, filename, Name + "_" + number + '.nii.gz')
+ break
+ # prepare for nnUNet testing scans
+ for i in range(len(test_images)):
+ # put this image to the test folder
+ shutil.copy(test_images[i], test_dir)
+ filename = os.path.basename(test_images[i])
+ filename1 = os.path.basename(test_images[i]).split(".")[0]
+ number = ''.join(filter(lambda x: x.isdigit(), filename1))
+ rename(test_dir, filename, Name + "_" + number + "_0000.nii.gz")
+
+ # create json file
+ json_dict = OrderedDict()
+ json_dict['name'] = task_name
+ json_dict['description'] = Name
+ json_dict['tensorImageSize'] = "4D"
+ json_dict['reference'] = "MODIFY"
+ json_dict['licence'] = "MODIFY"
+ json_dict['release'] = "0.0"
+ json_dict['modality'] = {
+ "0": "CT"
+ }
+ json_dict['labels'] = {
+ "0": "background",
+ }
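+    # -sl pairs extend the label map: label value first, then label name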
+ for i in range(0, len(seg_list), 2):
+            assert seg_list[i].isdigit(), "expected a label value (integer) first in each -sl pair"
+            assert not seg_list[i + 1].isdigit(), "expected a label name after each label value"
+ json_dict['labels'].update({
+ seg_list[i]: seg_list[i + 1]
+ })
+ train_ids = os.listdir(train_image_dir)
+ test_ids = os.listdir(test_dir)
+ json_dict['numTraining'] = len(train_ids)
+ json_dict['numTest'] = len(test_ids)
+ json_dict['training'] = [{'image': "./imagesTr/%s" % (i[:i.find(
+ "_0000")]+'.nii.gz'), "label": "./labelsTr/%s" % (i[:i.find("_0000")]+'.nii.gz')} for i in train_ids]
+ json_dict['test'] = ["./imagesTs/%s" %
+ (i[:i.find("_0000")]+'.nii.gz') for i in test_ids]
+
+ with open(os.path.join(task_folder_name, "dataset.json"), 'w') as f:
+ json.dump(json_dict, f, indent=4, sort_keys=True)
+
+ if os.path.exists(os.path.join(task_folder_name, 'dataset.json')):
+ print("new json file created!")
+
+
+if __name__ == '__main__':
+ main()
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..530fc20f9e18cf6472a78d6b931676854913abd9
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,158 @@
+antspyx==0.3.1
+argon2-cffi==21.3.0
+argon2-cffi-bindings==21.2.0
+asgiref==3.5.0
+asttokens==2.0.5
+attrs==21.4.0
+backcall==0.2.0
+backports.zoneinfo==0.2.1
+batchgenerators==0.23
+beautifulsoup4==4.10.0
+black==22.1.0
+bleach==4.1.0
+Brotli==1.0.9
+cachetools==5.0.0
+certifi==2021.10.8
+cffi==1.15.0
+charset-normalizer==2.0.12
+chart-studio==1.1.0
+click==8.0.3
+cycler==0.11.0
+debugpy==1.5.1
+decorator==5.1.1
+defusedxml==0.7.1
+dicom2nifti==2.3.0
+Django==4.0.2
+entrypoints==0.4
+executing==0.8.2
+filelock==3.6.0
+fonttools==4.29.1
+future==0.18.2
+gdown==4.4.0
+gevent==21.12.0
+google-api-core==2.7.1
+google-api-python-client==2.43.0
+google-auth==2.6.2
+google-auth-httplib2==0.1.0
+googleapis-common-protos==1.56.0
+greenlet==1.1.2
+httplib2==0.20.4
+idna==3.3
+imageio==2.16.0
+importlib-resources==5.4.0
+ipykernel==6.9.1
+ipython==8.0.1
+ipython-genutils==0.2.0
+ipywidgets==7.6.5
+jedi==0.18.1
+Jinja2==3.0.3
+joblib==1.1.0
+jsonschema==4.4.0
+jupyter==1.0.0
+jupyter-client==7.1.2
+jupyter-console==6.4.0
+jupyter-core==4.9.2
+jupyterlab-pygments==0.1.2
+jupyterlab-widgets==1.0.2
+kiwisolver==1.3.2
+linecache2==1.0.0
+loguru==0.6.0
+MarkupSafe==2.0.1
+matplotlib==3.5.1
+matplotlib-inline==0.1.3
+MedPy==0.4.0
+mistune==0.8.4
+multivolumefile==0.2.3
+mypy-extensions==0.4.3
+nbclient==0.5.11
+nbconvert==6.4.2
+nbformat==5.1.3
+nest-asyncio==1.5.4
+networkx==2.6.3
+nibabel==3.2.2
+-e git+https://github.com/MIC-DKFZ/nnUNet.git@b16142ac0d15e4098d9b6c9a2b828b8dc4957c2f#egg=nnunet
+notebook==6.4.8
+numpy==1.22.2
+oauth2client==4.1.3
+packaging==21.3
+pandas==1.4.1
+pandocfilters==1.5.0
+parso==0.8.3
+pathspec==0.9.0
+patsy==0.5.2
+pexpect==4.8.0
+pickleshare==0.7.5
+Pillow==9.0.1
+platformdirs==2.5.0
+plotly==5.6.0
+prometheus-client==0.13.1
+prompt-toolkit==3.0.28
+protobuf==3.20.0
+psutil==5.9.0
+ptyprocess==0.7.0
+pure-eval==0.2.2
+py7zr==0.18.1
+pyasn1==0.4.8
+pyasn1-modules==0.2.8
+pybcj==0.5.0
+pycparser==2.21
+pycryptodomex==3.14.1
+pydicom==2.2.2
+PyDrive==1.3.1
+Pygments==2.11.2
+pynrrd==0.4.2
+pyparsing==3.0.7
+pyppmd==0.17.4
+pyrsistent==0.18.1
+PySocks==1.7.1
+python-dateutil==2.8.2
+pytz==2021.3
+PyWavelets==1.2.0
+PyYAML==6.0
+pyzmq==22.3.0
+pyzstd==0.15.2
+qtconsole==5.2.2
+QtPy==2.0.1
+rclone==0.3
+requests==2.27.1
+retrying==1.3.3
+rsa==4.8
+scikit-image==0.19.1
+scikit-learn==1.0.2
+scipy==1.8.0
+Send2Trash==1.8.0
+SimpleITK==2.1.1
+six==1.16.0
+sklearn==0.0
+slicerio==0.1.3
+soupsieve==2.3.1
+sqlparse==0.4.2
+stack-data==0.2.0
+statsmodels==0.13.2
+tenacity==8.0.1
+terminado==0.13.1
+testpath==0.5.0
+texttable==1.6.4
+threadpoolctl==3.1.0
+tifffile==2022.2.9
+tomli==2.0.1
+torch==1.10.2+cu113
+torchaudio==0.10.2+cu113
+torchvision==0.11.3+cu113
+tornado==6.1
+tqdm==4.62.3
+traceback2==1.4.0
+traitlets==5.1.1
+typing_extensions==4.1.1
+unittest2==1.1.0
+unzip==1.0.0
+uritemplate==4.1.1
+urllib3==1.26.8
+wcwidth==0.2.5
+webcolors==1.11.1
+webencodings==0.5.1
+widgetsnbextension==3.5.2
+zipfile-deflate64==0.2.0
+zipp==3.7.0
+zope.event==4.5.0
+zope.interface==5.4.0