yufanh-nv commited on
Commit
d9097f2
·
verified ·
1 Parent(s): c55c320

Initial commit

Browse files
LICENSE ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Code License
2
+
3
+ This license applies to all files except the model weights in the directory.
4
+
5
+ Apache License
6
+ Version 2.0, January 2004
7
+ http://www.apache.org/licenses/
8
+
9
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
10
+
11
+ 1. Definitions.
12
+
13
+ "License" shall mean the terms and conditions for use, reproduction,
14
+ and distribution as defined by Sections 1 through 9 of this document.
15
+
16
+ "Licensor" shall mean the copyright owner or entity authorized by
17
+ the copyright owner that is granting the License.
18
+
19
+ "Legal Entity" shall mean the union of the acting entity and all
20
+ other entities that control, are controlled by, or are under common
21
+ control with that entity. For the purposes of this definition,
22
+ "control" means (i) the power, direct or indirect, to cause the
23
+ direction or management of such entity, whether by contract or
24
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
25
+ outstanding shares, or (iii) beneficial ownership of such entity.
26
+
27
+ "You" (or "Your") shall mean an individual or Legal Entity
28
+ exercising permissions granted by this License.
29
+
30
+ "Source" form shall mean the preferred form for making modifications,
31
+ including but not limited to software source code, documentation
32
+ source, and configuration files.
33
+
34
+ "Object" form shall mean any form resulting from mechanical
35
+ transformation or translation of a Source form, including but
36
+ not limited to compiled object code, generated documentation,
37
+ and conversions to other media types.
38
+
39
+ "Work" shall mean the work of authorship, whether in Source or
40
+ Object form, made available under the License, as indicated by a
41
+ copyright notice that is included in or attached to the work
42
+ (an example is provided in the Appendix below).
43
+
44
+ "Derivative Works" shall mean any work, whether in Source or Object
45
+ form, that is based on (or derived from) the Work and for which the
46
+ editorial revisions, annotations, elaborations, or other modifications
47
+ represent, as a whole, an original work of authorship. For the purposes
48
+ of this License, Derivative Works shall not include works that remain
49
+ separable from, or merely link (or bind by name) to the interfaces of,
50
+ the Work and Derivative Works thereof.
51
+
52
+ "Contribution" shall mean any work of authorship, including
53
+ the original version of the Work and any modifications or additions
54
+ to that Work or Derivative Works thereof, that is intentionally
55
+ submitted to Licensor for inclusion in the Work by the copyright owner
56
+ or by an individual or Legal Entity authorized to submit on behalf of
57
+ the copyright owner. For the purposes of this definition, "submitted"
58
+ means any form of electronic, verbal, or written communication sent
59
+ to the Licensor or its representatives, including but not limited to
60
+ communication on electronic mailing lists, source code control systems,
61
+ and issue tracking systems that are managed by, or on behalf of, the
62
+ Licensor for the purpose of discussing and improving the Work, but
63
+ excluding communication that is conspicuously marked or otherwise
64
+ designated in writing by the copyright owner as "Not a Contribution."
65
+
66
+ "Contributor" shall mean Licensor and any individual or Legal Entity
67
+ on behalf of whom a Contribution has been received by Licensor and
68
+ subsequently incorporated within the Work.
69
+
70
+ 2. Grant of Copyright License. Subject to the terms and conditions of
71
+ this License, each Contributor hereby grants to You a perpetual,
72
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
73
+ copyright license to reproduce, prepare Derivative Works of,
74
+ publicly display, publicly perform, sublicense, and distribute the
75
+ Work and such Derivative Works in Source or Object form.
76
+
77
+ 3. Grant of Patent License. Subject to the terms and conditions of
78
+ this License, each Contributor hereby grants to You a perpetual,
79
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
80
+ (except as stated in this section) patent license to make, have made,
81
+ use, offer to sell, sell, import, and otherwise transfer the Work,
82
+ where such license applies only to those patent claims licensable
83
+ by such Contributor that are necessarily infringed by their
84
+ Contribution(s) alone or by combination of their Contribution(s)
85
+ with the Work to which such Contribution(s) was submitted. If You
86
+ institute patent litigation against any entity (including a
87
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
88
+ or a Contribution incorporated within the Work constitutes direct
89
+ or contributory patent infringement, then any patent licenses
90
+ granted to You under this License for that Work shall terminate
91
+ as of the date such litigation is filed.
92
+
93
+ 4. Redistribution. You may reproduce and distribute copies of the
94
+ Work or Derivative Works thereof in any medium, with or without
95
+ modifications, and in Source or Object form, provided that You
96
+ meet the following conditions:
97
+
98
+ (a) You must give any other recipients of the Work or
99
+ Derivative Works a copy of this License; and
100
+
101
+ (b) You must cause any modified files to carry prominent notices
102
+ stating that You changed the files; and
103
+
104
+ (c) You must retain, in the Source form of any Derivative Works
105
+ that You distribute, all copyright, patent, trademark, and
106
+ attribution notices from the Source form of the Work,
107
+ excluding those notices that do not pertain to any part of
108
+ the Derivative Works; and
109
+
110
+ (d) If the Work includes a "NOTICE" text file as part of its
111
+ distribution, then any Derivative Works that You distribute must
112
+ include a readable copy of the attribution notices contained
113
+ within such NOTICE file, excluding those notices that do not
114
+ pertain to any part of the Derivative Works, in at least one
115
+ of the following places: within a NOTICE text file distributed
116
+ as part of the Derivative Works; within the Source form or
117
+ documentation, if provided along with the Derivative Works; or,
118
+ within a display generated by the Derivative Works, if and
119
+ wherever such third-party notices normally appear. The contents
120
+ of the NOTICE file are for informational purposes only and
121
+ do not modify the License. You may add Your own attribution
122
+ notices within Derivative Works that You distribute, alongside
123
+ or as an addendum to the NOTICE text from the Work, provided
124
+ that such additional attribution notices cannot be construed
125
+ as modifying the License.
126
+
127
+ You may add Your own copyright statement to Your modifications and
128
+ may provide additional or different license terms and conditions
129
+ for use, reproduction, or distribution of Your modifications, or
130
+ for any such Derivative Works as a whole, provided Your use,
131
+ reproduction, and distribution of the Work otherwise complies with
132
+ the conditions stated in this License.
133
+
134
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
135
+ any Contribution intentionally submitted for inclusion in the Work
136
+ by You to the Licensor shall be under the terms and conditions of
137
+ this License, without any additional terms or conditions.
138
+ Notwithstanding the above, nothing herein shall supersede or modify
139
+ the terms of any separate license agreement you may have executed
140
+ with Licensor regarding such Contributions.
141
+
142
+ 6. Trademarks. This License does not grant permission to use the trade
143
+ names, trademarks, service marks, or product names of the Licensor,
144
+ except as required for reasonable and customary use in describing the
145
+ origin of the Work and reproducing the content of the NOTICE file.
146
+
147
+ 7. Disclaimer of Warranty. Unless required by applicable law or
148
+ agreed to in writing, Licensor provides the Work (and each
149
+ Contributor provides its Contributions) on an "AS IS" BASIS,
150
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
151
+ implied, including, without limitation, any warranties or conditions
152
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
153
+ PARTICULAR PURPOSE. You are solely responsible for determining the
154
+ appropriateness of using or redistributing the Work and assume any
155
+ risks associated with Your exercise of permissions under this License.
156
+
157
+ 8. Limitation of Liability. In no event and under no legal theory,
158
+ whether in tort (including negligence), contract, or otherwise,
159
+ unless required by applicable law (such as deliberate and grossly
160
+ negligent acts) or agreed to in writing, shall any Contributor be
161
+ liable to You for damages, including any direct, indirect, special,
162
+ incidental, or consequential damages of any character arising as a
163
+ result of this License or out of the use or inability to use the
164
+ Work (including but not limited to damages for loss of goodwill,
165
+ work stoppage, computer failure or malfunction, or any and all
166
+ other commercial damages or losses), even if such Contributor
167
+ has been advised of the possibility of such damages.
168
+
169
+ 9. Accepting Warranty or Additional Liability. While redistributing
170
+ the Work or Derivative Works thereof, You may choose to offer,
171
+ and charge a fee for, acceptance of support, warranty, indemnity,
172
+ or other liability obligations and/or rights consistent with this
173
+ License. However, in accepting such obligations, You may act only
174
+ on Your own behalf and on Your sole responsibility, not on behalf
175
+ of any other Contributor, and only if You agree to indemnify,
176
+ defend, and hold each Contributor harmless for any liability
177
+ incurred by, or claims asserted against, such Contributor by reason
178
+ of your accepting any such warranty or additional liability.
179
+
180
+ END OF TERMS AND CONDITIONS
181
+
182
+ APPENDIX: How to apply the Apache License to your work.
183
+
184
+ To apply the Apache License to your work, attach the following
185
+ boilerplate notice, with the fields enclosed by brackets "[]"
186
+ replaced with your own identifying information. (Don't include
187
+ the brackets!) The text should be enclosed in the appropriate
188
+ comment syntax for the file format. We also recommend that a
189
+ file or class name and description of purpose be included on the
190
+ same "printed page" as the copyright notice for easier
191
+ identification within third-party archives.
192
+
193
+ Copyright [yyyy] [name of copyright owner]
194
+
195
+ Licensed under the Apache License, Version 2.0 (the "License");
196
+ you may not use this file except in compliance with the License.
197
+ You may obtain a copy of the License at
198
+
199
+ http://www.apache.org/licenses/LICENSE-2.0
200
+
201
+ Unless required by applicable law or agreed to in writing, software
202
+ distributed under the License is distributed on an "AS IS" BASIS,
203
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
204
+ See the License for the specific language governing permissions and
205
+ limitations under the License.
206
+
207
+ ------------------------------------------------------------------------------
208
+
209
+ Model Weights License
210
+
211
+ This license applies to model weights in the directory.
212
+
213
+ NVIDIA License
214
+
215
+ 1. Definitions
216
+
217
+ “Licensor” means any person or entity that distributes its Work.
218
+ “Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license.
219
+ The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work.
220
+ Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license.
221
+
222
+ 2. License Grant
223
+
224
+ 2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form.
225
+
226
+ 3. Limitations
227
+
228
+ 3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work.
229
+
230
+ 3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works, and (b) you identify the specific derivative works that are subject to Your Terms. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself.
231
+
232
+ 3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. Notwithstanding the foregoing, NVIDIA Corporation and its affiliates may use the Work and any derivative works commercially. As used herein, “non-commercially” means for research or evaluation purposes only.
233
+
234
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately.
235
+
236
+ 3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license.
237
+
238
+ 3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately.
239
+
240
+ 4. Disclaimer of Warranty.
241
+
242
+ THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
243
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE.
244
+
245
+ 5. Limitation of Liability.
246
+
247
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES.
README.md ADDED
@@ -0,0 +1,192 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NV-Segment-CTMR Overview
2
+
3
+ ## Description:
4
+ NV-Segment-CTMR is a specialized foundation model for 3D medical image segmentation that excels at accurate, adaptable, automatic segmentation across anatomies and modalities, including computed tomography (CT) and magnetic resonance (MR) imaging. NV-Segment-CTMR adapts to varying conditions and anatomical regions, enabling comprehensive automated annotation workflows.
5
+
6
+ At the core of NV-Segment-CTMR are two automated workflows. Segment Everything enables whole-body exploration, which is crucial for understanding complex diseases affecting multiple organs and for holistic treatment planning. Segment by Class provides detailed sectional views based on specific classes, supporting targeted disease analysis or organ mapping, such as tumor identification in critical organs.
7
+
8
+ This model is for research and development only.
9
+
10
+
11
+ ## Run pipeline:
12
+ For running the pipeline, NV-Segment-CTMR requires at least one prompt for segmentation. It supports label prompt, which is the index of the class for automatic segmentation. NV-Segment-CTMR does not support point based interactive segmentation. For interactive model, please refer to [VISTA3D](https://github.com/Project-MONAI/VISTA/tree/main/vista3ds)
13
+
14
+ Here is a code snippet to showcase how to execute inference with this model.
15
+ ```python
16
+ import os
17
+ import tempfile
18
+
19
+ import torch
20
+ from hugging_face_pipeline import HuggingFacePipelineHelper
21
+
22
+
23
+ FILE_PATH = os.path.dirname(__file__)
24
+ with tempfile.TemporaryDirectory() as tmp_dir:
25
+ output_dir = os.path.join(tmp_dir, "output_dir")
26
+ pipeline_helper = HuggingFacePipelineHelper("vista3d")
27
+ pipeline = pipeline_helper.init_pipeline(
28
+ os.path.join(FILE_PATH, "vista3d_pretrained_model"),
29
+ device=torch.device("cuda:0"),
30
+ )
31
+ inputs = [
32
+ {
33
+ "image": "/data/Task09_Spleen/imagesTs/spleen_1.nii.gz",
34
+ "label_prompt": [3],
35
+ },
36
+ {
37
+ "image": "/data/Task09_Spleen/imagesTs/spleen_11.nii.gz",
38
+ "modality": 'CT_BODY',
39
+ },
40
+ {
41
+ "image": "/data/Task09_Spleen/imagesTs/spleen_11.nii.gz",
42
+ "modality": 'MRI_BODY'
43
+ },
44
+ ]
45
+ pipeline(inputs, output_dir=output_dir)
46
+
47
+ ```
48
+ The inputs defines the image to segment and the prompt for segmentation.
49
+ ```python
50
+ inputs = {'image': '/data/Task09_Spleen/imagesTs/spleen_15.nii.gz', 'label_prompt':[1]}
51
+ ```
52
+ - The inputs must include the key `image` which contain the absolute path to the nii image file, and includes prompt keys of `label_prompt`.
53
+ - The `label_prompt` is a list of length `B`, which can perform `B` foreground objects segmentation, e.g. `[2,3,4,5]`. The full list of label definition is in `metadata.json`.
54
+ - If no prompt is provided, user can use `modality` to use predefined class indices. Supported modality includes `CT_BODY`, `MRI_BODY`, `MRI_BRAIN`.
55
+ ```
56
+ Note: For brain structure segmentation, current model only support standard brain T1 images. The brain T1 images must be preprocessed with skull stripping and normalization. Follow https://github.com/junyuchen245/MIR/tree/main/tutorials/brain_MRI_preprocessing to process the brain images
57
+ ```
58
+
59
+ ### License/Terms of Use:
60
+ NVIDIA OneWay Non-Commercial License for academic research purposes
61
+
62
+ ### Deployment Geography:
63
+ Global
64
+
65
+ ### Use Case:
66
+ Medical researchers, AI developers, and healthcare institutions are expected to use this system to perform automated medical image segmentation, conduct multi-organ analysis, and accelerate annotation workflows in research applications.
67
+
68
+ ### Release Date:
69
+ Huggingface: 10/27/2025 via https://huggingface.co/NVIDIA
70
+
71
+ ## Reference(s):
72
+ [1] He, Yufan, et al. "VISTA3D: A Unified Segmentation Foundation Model For 3D Medical Imaging." arXiv preprint arXiv:2406.05285. 2024. https://arxiv.org/abs/2406.05285
73
+
74
+ ## Model Architecture:
75
+ **Architecture Type:** Transformer
76
+ **Network Architecture:** SAM-like architecture for 3D medical imaging segmentation
77
+
78
+ This model was developed from scratch using MONAI components.
79
+ **Number of model parameters:** 218M
80
+
81
+ ## Input:
82
+ **Input Type(s):** Image
83
+ **Input Format(s):** Neuroimaging Informatics Technology Initiative (NIfTI)
84
+ **Input Parameters:** Three-Dimensional (3D)
85
+ **Other Properties Related to Input:** Supports both computed tomography (CT) and magnetic resonance (MR) imaging modalities. It also supports optional class information for targeted segmentation workflows.
86
+
87
+ ### Input Modalities:
88
+ - **CT Images:** 3D computed tomography volumes
89
+ - **MR Images:** 3D magnetic resonance volumes
90
+ - **Class Selection:** Optional class indices for targeted segmentation workflows
91
+
92
+ ## Output:
93
+ **Output Type(s):** Image
94
+ **Output Format:** Neuroimaging Informatics Technology Initiative (NIfTI)
95
+ **Output Parameters:** Three-Dimensional (3D)
96
+ **Other Properties Related to Output:** Segmentation masks with up to 345+ anatomical classes, providing comprehensive organ and tissue delineation for medical imaging analysis.
97
+
98
+ Our AI models are designed and/or optimized to run on NVIDIA GPU-accelerated systems. By leveraging NVIDIA's hardware (GPU cores) and software frameworks (CUDA libraries), the model achieves faster training and inference times compared to CPU-only solutions.
99
+
100
+ ## Software Integration:
101
+ **Runtime Engine(s):**
102
+ * MONAI Core v.1.5.0
103
+
104
+ **Supported Hardware Microarchitecture Compatibility:**
105
+ * NVIDIA Ampere
106
+ * NVIDIA Hopper
107
+
108
+ **Supported Operating System(s):**
109
+ * Linux
110
+
111
+ The integration of foundation and fine-tuned models into AI systems requires additional testing using use-case-specific data to ensure safe and effective deployment. Following the V-model methodology, iterative testing and validation at both unit and system levels are essential to mitigate risks, meet technical and functional requirements, and ensure compliance with safety and ethical standards before deployment.
112
+
113
+ ## Model Version(s):
114
+ 0.1 - Initial release version for 3D medical imaging segmentation with multi-modality support
115
+
116
+ ## Training, Testing, and Evaluation Datasets:
117
+
118
+ ### Dataset Overview:
119
+ **Total Size:** ~31k
120
+ **Total Number of Datasets:** 32 datasets
121
+
122
+ Public datasets from multiple scanner types were processed to create standardized 3D medical imaging volumes with expert-validated anatomical segmentation masks across diverse anatomical regions and pathological conditions. The data processing pipeline ensured consistent voxel spacing, standardized orientations, and validated anatomical segmentations.
123
+
124
+ ## Training Dataset:
125
+ **Data Modality:**
126
+ * Image
127
+
128
+ **Image Training Data Size:**
129
+ * Less than a Million Images
130
+
131
+ **Data Collection Method by dataset:**
132
+ * Hybrid: Human, Automatic/Sensors
133
+
134
+ **Labeling Method by dataset:**
135
+ * Hybrid: Human, Automatic/Sensors
136
+
137
+ ## Testing Dataset:
138
+ **Data Collection Method by dataset:**
139
+ * Hybrid: Human, Automatic/Sensors
140
+
141
+ **Labeling Method by dataset:**
142
+ * Hybrid: Human, Automatic/Sensors
143
+
144
+ ## Evaluation Dataset:
145
+ **Data Collection Method by dataset:**
146
+ * Hybrid: Human, Automatic/Sensors
147
+
148
+ **Labeling Method by dataset:**
149
+ * Hybrid: Human, Automatic/Sensors
150
+
151
+ ## Inference:
152
+ **Acceleration Engine:** PyTorch
153
+ **Test Hardware:**
154
+ * A100
155
+ * H100
156
+
157
+ ## Additional Information:
158
+ ### Available Anatomical Classes (345+ total):
159
+ NV-Segment-CTMR supports comprehensive anatomical segmentation with the following categories:
160
+
161
+ **Core Organs and Systems:**
162
+ - **Abdominal organs:** liver (1), kidney (2), spleen (3), pancreas (4), gallbladder (10), stomach (12), bladder (15), colon (62)
163
+ - **Cardiovascular:** heart (115), aorta (6), inferior vena cava (7), superior vena cava (125), portal and splenic veins (17)
164
+ - **Respiratory:** lung (20), trachea (57), airway (132), individual lung lobes (28-32)
165
+ - **Neurological:** brain (22), spinal cord (121), complete brain structures (214-345)
166
+
167
+ **Skeletal System:**
168
+ - **Spine:** Complete vertebral column from C1-S1 (33-56, 127)
169
+ - **Thoracic:** Bilateral ribs 1-12 (63-86), sternum (122), costal cartilages (114)
170
+ - **Appendicular:** Bilateral long bones, joints, and extremities (87-96)
171
+
172
+ **Detailed Brain Segmentation:**
173
+ Comprehensive brain parcellation including ventricles, cortical regions, subcortical structures, and specialized brain areas (214-345) based on neuroanatomical atlases.
174
+
175
+ **Pathological Structures:**
176
+ - **Tumors:** lung tumor (23), pancreatic tumor (24), hepatic tumor (26), brain tumor (176)
177
+ - **Cancer:** colon cancer primaries (27)
178
+ - **Lesions:** bone lesion (128)
179
+ - **Cysts:** bilateral kidney cysts (116-117)
180
+ - **Note:** We recommend the `NV-Segment-CT` model for better tumor performance.
181
+
182
+ **Specialized Regions:**
183
+ - **Head and neck:** detailed facial structures, sensory organs, and cranial anatomy (172-213)
184
+ - **Cardiac:** heart chambers, major vessels, and cardiac-specific structures (108, 149-155)
185
+ - **Reproductive:** prostate zones (118, 147-148), uterocervix (161), gonads (160)
186
+
187
+ *Complete numerical mapping and deprecated classes available in model documentation.*
188
+
189
+ ## Ethical Considerations:
190
+ NVIDIA believes Trustworthy AI is a shared responsibility and we have established policies and practices to enable development for a wide array of AI applications. When downloaded or used in accordance with our terms of service, developers should work with their internal model team to ensure this model meets requirements for the relevant industry and use case and addresses unforeseen product misuse. Please make sure you have proper rights and permissions for all input image and video content; if image or video includes people, personal health information, or intellectual property, the image or video generated will not blur or maintain proportions of image subjects included.
191
+
192
+ Please report model quality, risk, security vulnerabilities or concerns [here](https://www.nvidia.com/en-us/support/submit-security-vulnerability/).
__init__.py ADDED
File without changes
data_license.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ Third Party Licenses
2
+ -----------------------------------------------------------------------
3
+
4
+ /*********************************************************************/
5
+ i. Medical Segmentation Decathlon
6
+ http://medicaldecathlon.com/
hugging_face_pipeline.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import pipeline
2
+ from vista3d_config import VISTA3DConfig
3
+ from vista3d_model import VISTA3DModel, register_my_model
4
+ from vista3d_pipeline import VISTA3DPipeline, register_simple_pipeline
5
+
6
+
7
+ class HuggingFacePipelineHelper:
8
+
9
+ def __init__(self, pipeline_name: str = "vista3d"):
10
+ self.pipeline_name = pipeline_name
11
+
12
+ def __model_register(self):
13
+ register_my_model()
14
+
15
+ def __pipeline_register(self):
16
+ register_simple_pipeline()
17
+
18
+ def get_pipeline(self):
19
+ self.__model_register()
20
+ self.__pipeline_register()
21
+ return pipeline(self.pipeline_name)
22
+
23
+ def _update_config(self, config, config_dict):
24
+ if config_dict:
25
+ for key in config_dict:
26
+ if hasattr(config, key) and getattr(config, key) != config_dict[key]:
27
+ setattr(config, key, config_dict[key])
28
+ return config
29
+
30
+ def init_pipeline(self, pretrained_model_name_or_path: str, **kwargs):
31
+ config = VISTA3DConfig()
32
+ config_dict = kwargs.pop("config_dict", None)
33
+ self._update_config(config, config_dict)
34
+ model = VISTA3DModel(config)
35
+ model = model.from_pretrained(
36
+ pretrained_model_name_or_path=pretrained_model_name_or_path
37
+ )
38
+ return VISTA3DPipeline(model, **kwargs)
metadata.json ADDED
@@ -0,0 +1,750 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20240725.json",
3
+ "version": "0.0.1",
4
+ "changelog": {
5
+ "0.0.1": "NV-Segment-CTMR initial commit"
6
+ },
7
+ "monai_version": "1.4.0",
8
+ "pytorch_version": "2.4.0",
9
+ "numpy_version": "1.24.4",
10
+ "required_packages_version": {
11
+ "matplotlib": "3.9.1",
12
+ "einops": "0.7.0",
13
+ "scikit-image": "0.23.2",
14
+ "nibabel": "5.2.1",
15
+ "pytorch-ignite": "0.4.11",
16
+ "cucim-cu12": "24.6.0",
17
+ "mlflow": "2.17.2",
18
+ "tensorboard": "2.17.0"
19
+ },
20
+ "supported_apps": {
21
+ "vista3d-nim": ""
22
+ },
23
+ "name": "VISTA-3D: Versatile Imaging SegmenTation and Annotation",
24
+ "task": "Multi-organ Segmentation in CT Scans with Zero-shot Learning",
25
+ "description": "A 3D segmentation model that processes 128x128x128 pixel patches from CT scans to identify and delineate over 130 anatomical structures. The model employs zero-shot learning capabilities to adapt to new anatomical targets without retraining, supporting comprehensive volumetric analysis of organs, bones, muscles, and pathological findings.",
26
+ "authors": "MONAI team",
27
+ "copyright": "Copyright (c) MONAI Consortium",
28
+ "data_source": "Task09_Spleen.tar from http://medicaldecathlon.com/",
29
+ "data_type": "nibabel",
30
+ "image_classes": "1 channel data, intensity scaled to [0, 1]",
31
+ "label_classes": "single channel data",
32
+ "pred_classes": "2 channels OneHot data",
33
+ "intended_use": "This is an example, not to be used for diagnostic purposes",
34
+ "references": [],
35
+ "network_data_format": {
36
+ "inputs": {
37
+ "image": {
38
+ "type": "image",
39
+ "format": "hounsfield",
40
+ "modality": ["CT", "MRI"],
41
+ "num_channels": 1,
42
+ "spatial_shape": [
43
+ 192,
44
+ 192,
45
+ 128
46
+ ],
47
+ "dtype": "float32",
48
+ "value_range": [
49
+ 0,
50
+ 1
51
+ ],
52
+ "is_patch_data": true,
53
+ "channel_def": {
54
+ "0": "image"
55
+ }
56
+ }
57
+ },
58
+ "outputs": {
59
+ "pred": {
60
+ "type": "image",
61
+ "format": "segmentation",
62
+ "num_channels": 1,
63
+ "spatial_shape": [
64
+ 192,
65
+ 192,
66
+ 128
67
+ ],
68
+ "dtype": "float32",
69
+ "value_range": [
70
+ 0,
71
+ 1
72
+ ],
73
+ "is_patch_data": true,
74
+ "channel_def": {
75
+ "0": "background",
76
+ "1": "liver",
77
+ "2": "kidney",
78
+ "3": "spleen",
79
+ "4": "pancreas",
80
+ "5": "right kidney",
81
+ "6": "aorta",
82
+ "7": "inferior vena cava",
83
+ "8": "right adrenal gland",
84
+ "9": "left adrenal gland",
85
+ "10": "gallbladder",
86
+ "11": "esophagus",
87
+ "12": "stomach",
88
+ "13": "duodenum",
89
+ "14": "left kidney",
90
+ "15": "bladder",
91
+ "16": "prostate or uterus (deprecated)",
92
+ "17": "portal vein and splenic vein",
93
+ "18": "rectum",
94
+ "19": "small bowel",
95
+ "20": "lung",
96
+ "21": "bone",
97
+ "22": "brain",
98
+ "23": "lung tumor",
99
+ "24": "pancreatic tumor",
100
+ "25": "hepatic vessel",
101
+ "26": "hepatic tumor",
102
+ "27": "colon cancer primaries",
103
+ "28": "left lung upper lobe",
104
+ "29": "left lung lower lobe",
105
+ "30": "right lung upper lobe",
106
+ "31": "right lung middle lobe",
107
+ "32": "right lung lower lobe",
108
+ "33": "vertebrae L5",
109
+ "34": "vertebrae L4",
110
+ "35": "vertebrae L3",
111
+ "36": "vertebrae L2",
112
+ "37": "vertebrae L1",
113
+ "38": "vertebrae T12",
114
+ "39": "vertebrae T11",
115
+ "40": "vertebrae T10",
116
+ "41": "vertebrae T9",
117
+ "42": "vertebrae T8",
118
+ "43": "vertebrae T7",
119
+ "44": "vertebrae T6",
120
+ "45": "vertebrae T5",
121
+ "46": "vertebrae T4",
122
+ "47": "vertebrae T3",
123
+ "48": "vertebrae T2",
124
+ "49": "vertebrae T1",
125
+ "50": "vertebrae C7",
126
+ "51": "vertebrae C6",
127
+ "52": "vertebrae C5",
128
+ "53": "vertebrae C4",
129
+ "54": "vertebrae C3",
130
+ "55": "vertebrae C2",
131
+ "56": "vertebrae C1",
132
+ "57": "trachea",
133
+ "58": "left iliac artery",
134
+ "59": "right iliac artery",
135
+ "60": "left iliac vena",
136
+ "61": "right iliac vena",
137
+ "62": "colon",
138
+ "63": "left rib 1",
139
+ "64": "left rib 2",
140
+ "65": "left rib 3",
141
+ "66": "left rib 4",
142
+ "67": "left rib 5",
143
+ "68": "left rib 6",
144
+ "69": "left rib 7",
145
+ "70": "left rib 8",
146
+ "71": "left rib 9",
147
+ "72": "left rib 10",
148
+ "73": "left rib 11",
149
+ "74": "left rib 12",
150
+ "75": "right rib 1",
151
+ "76": "right rib 2",
152
+ "77": "right rib 3",
153
+ "78": "right rib 4",
154
+ "79": "right rib 5",
155
+ "80": "right rib 6",
156
+ "81": "right rib 7",
157
+ "82": "right rib 8",
158
+ "83": "right rib 9",
159
+ "84": "right rib 10",
160
+ "85": "right rib 11",
161
+ "86": "right rib 12",
162
+ "87": "left humerus",
163
+ "88": "right humerus",
164
+ "89": "left scapula",
165
+ "90": "right scapula",
166
+ "91": "left clavicula",
167
+ "92": "right clavicula",
168
+ "93": "left femur",
169
+ "94": "right femur",
170
+ "95": "left hip",
171
+ "96": "right hip",
172
+ "97": "sacrum",
173
+ "98": "left gluteus maximus",
174
+ "99": "right gluteus maximus",
175
+ "100": "left gluteus medius",
176
+ "101": "right gluteus medius",
177
+ "102": "left gluteus minimus",
178
+ "103": "right gluteus minimus",
179
+ "104": "left autochthon",
180
+ "105": "right autochthon",
181
+ "106": "left iliopsoas",
182
+ "107": "right iliopsoas",
183
+ "108": "left atrial appendage",
184
+ "109": "brachiocephalic trunk",
185
+ "110": "left brachiocephalic vein",
186
+ "111": "right brachiocephalic vein",
187
+ "112": "left common carotid artery",
188
+ "113": "right common carotid artery",
189
+ "114": "costal cartilages",
190
+ "115": "heart",
191
+ "116": "left kidney cyst",
192
+ "117": "right kidney cyst",
193
+ "118": "prostate",
194
+ "119": "pulmonary vein",
195
+ "120": "skull",
196
+ "121": "spinal cord",
197
+ "122": "sternum",
198
+ "123": "left subclavian artery",
199
+ "124": "right subclavian artery",
200
+ "125": "superior vena cava",
201
+ "126": "thyroid gland",
202
+ "127": "vertebrae S1",
203
+ "128": "bone lesion",
204
+ "129": "kidney mass (deprecated)",
205
+ "130": "liver tumor (deprecated)",
206
+ "131": "vertebrae L6 (deprecated)",
207
+ "132": "airway",
208
+ "133": "fibula (deprecated)",
209
+ "134": "intervertebral discs",
210
+ "135": "left lung",
211
+ "136": "right lung",
212
+ "137": "left quadriceps femoris (deprecated)",
213
+ "138": "right quadriceps femoris (deprecated)",
214
+ "139": "left sartorius (deprecated)",
215
+ "140": "right sartorius (deprecated)",
216
+ "141": "left thigh medial compartment (deprecated)",
217
+ "142": "right thigh medial compartment (deprecated)",
218
+ "143": "left thigh posterior compartment (deprecated)",
219
+ "144": "right thigh posterior compartment (deprecated)",
220
+ "145": "tibia (deprecated)",
221
+ "146": "vertebrae",
222
+ "147": "prostate transitional zone",
223
+ "148": "prostate peripheral zone",
224
+ "149": "left atrium",
225
+ "150": "white matter hyperintensity",
226
+ "151": "left ventricle",
227
+ "152": "right ventricle",
228
+ "153": "right atrium",
229
+ "154": "left ventricle myocardium",
230
+ "155": "ascending aorta (deprecated)",
231
+ "156": "muscles",
232
+ "157": "fat",
233
+ "158": "abdominal tissue",
234
+ "159": "mediastinal tissue",
235
+ "160": "gonads",
236
+ "161": "uterocervix",
237
+ "162": "uterus (deprecated)",
238
+ "163": "breast left",
239
+ "164": "breast right",
240
+ "165": "thyroid left",
241
+ "166": "thyroid right",
242
+ "167": "thymus",
243
+ "168": "skin",
244
+ "169": "heart tissue",
245
+ "170": "celiac trunk",
246
+ "171": "pulmonary artery",
247
+ "172": "cheek left",
248
+ "173": "cheek right",
249
+ "174": "eyeball left",
250
+ "175": "eyeball right",
251
+ "176": "brain tumor",
252
+ "177": "chiasm",
253
+ "178": "left temporal lobe",
254
+ "179": "right temporal lobe",
255
+ "180": "left eye",
256
+ "181": "right eye",
257
+ "182": "left lens",
258
+ "183": "right lens",
259
+ "184": "left optic nerve",
260
+ "185": "right optic nerve",
261
+ "186": "left middle ear",
262
+ "187": "right middle ear",
263
+ "188": "left internal auditory canal",
264
+ "189": "right internal auditory canal",
265
+ "190": "left tympanic cavity",
266
+ "191": "right tympanic cavity",
267
+ "192": "left vestibular semicircular canals",
268
+ "193": "right vestibular semicircular canals",
269
+ "194": "left cochlea",
270
+ "195": "right cochlea",
271
+ "196": "left ethmoid bone",
272
+ "197": "right ethmoid bone",
273
+ "198": "pituitary",
274
+ "199": "oral cavity",
275
+ "200": "left mandible",
276
+ "201": "right mandible",
277
+ "202": "left submandibular",
278
+ "203": "right submandibular",
279
+ "204": "left parotid",
280
+ "205": "right parotid",
281
+ "206": "left mastoid",
282
+ "207": "right mastoid",
283
+ "208": "left temporomandibular joint",
284
+ "209": "right temporomandibular joint",
285
+ "210": "larynx",
286
+ "211": "larynx glottic",
287
+ "212": "larynx supraglot",
288
+ "213": "pharynxConst",
289
+ "214": "3rd-Ventricle",
290
+ "215": "4th-Ventricle",
291
+ "216": "Right-Accumbens-Area",
292
+ "217": "Left-Accumbens-Area",
293
+ "218": "Right-Amygdala",
294
+ "219": "Left-Amygdala",
295
+ "220": "Brain-Stem",
296
+ "221": "Right-Caudate",
297
+ "222": "Left-Caudate",
298
+ "223": "Right-Cerebellum-Exterior",
299
+ "224": "Left-Cerebellum-Exterior",
300
+ "225": "Right-Cerebellum-White-Matter",
301
+ "226": "Left-Cerebellum-White-Matter",
302
+ "227": "Right-Cerebral-White-Matter",
303
+ "228": "Left-Cerebral-White-Matter",
304
+ "229": "Right-Hippocampus",
305
+ "230": "Left-Hippocampus",
306
+ "231": "Right-Inf-Lat-Vent",
307
+ "232": "Left-Inf-Lat-Vent",
308
+ "233": "Right-Lateral-Ventricle",
309
+ "234": "Left-Lateral-Ventricle",
310
+ "235": "Right-Pallidum",
311
+ "236": "Left-Pallidum",
312
+ "237": "Right-Putamen",
313
+ "238": "Left-Putamen",
314
+ "239": "Right-Thalamus-Proper",
315
+ "240": "Left-Thalamus-Proper",
316
+ "241": "Right-Ventral-DC",
317
+ "242": "Left-Ventral-DC",
318
+ "243": "Cerebellar-Vermal-Lobules-I-V",
319
+ "244": "Cerebellar-Vermal-Lobules-VI-VII",
320
+ "245": "Cerebellar-Vermal-Lobules-VIII-X",
321
+ "246": "Left-Basal-Forebrain",
322
+ "247": "Right-Basal-Forebrain",
323
+ "248": "Right-ACgG--anterior-cingulate-gyrus",
324
+ "249": "Left-ACgG--anterior-cingulate-gyrus",
325
+ "250": "Right-AIns--anterior-insula",
326
+ "251": "Left-AIns--anterior-insula",
327
+ "252": "Right-AOrG--anterior-orbital-gyrus",
328
+ "253": "Left-AOrG--anterior-orbital-gyrus",
329
+ "254": "Right-AnG---angular-gyrus",
330
+ "255": "Left-AnG---angular-gyrus",
331
+ "256": "Right-Calc--calcarine-cortex",
332
+ "257": "Left-Calc--calcarine-cortex",
333
+ "258": "Right-CO----central-operculum",
334
+ "259": "Left-CO----central-operculum",
335
+ "260": "Right-Cun---cuneus",
336
+ "261": "Left-Cun---cuneus",
337
+ "262": "Right-Ent---entorhinal-area",
338
+ "263": "Left-Ent---entorhinal-area",
339
+ "264": "Right-FO----frontal-operculum",
340
+ "265": "Left-FO----frontal-operculum",
341
+ "266": "Right-FRP---frontal-pole",
342
+ "267": "Left-FRP---frontal-pole",
343
+ "268": "Right-FuG---fusiform-gyrus",
344
+ "269": "Left-FuG---fusiform-gyrus",
345
+ "270": "Right-GRe---gyrus-rectus",
346
+ "271": "Left-GRe---gyrus-rectus",
347
+ "272": "Right-IOG---inferior-occipital-gyrus",
348
+ "273": "Left-IOG---inferior-occipital-gyrus",
349
+ "274": "Right-ITG---inferior-temporal-gyrus",
350
+ "275": "Left-ITG---inferior-temporal-gyrus",
351
+ "276": "Right-LiG---lingual-gyrus",
352
+ "277": "Left-LiG---lingual-gyrus",
353
+ "278": "Right-LOrG--lateral-orbital-gyrus",
354
+ "279": "Left-LOrG--lateral-orbital-gyrus",
355
+ "280": "Right-MCgG--middle-cingulate-gyrus",
356
+ "281": "Left-MCgG--middle-cingulate-gyrus",
357
+ "282": "Right-MFC---medial-frontal-cortex",
358
+ "283": "Left-MFC---medial-frontal-cortex",
359
+ "284": "Right-MFG---middle-frontal-gyrus",
360
+ "285": "Left-MFG---middle-frontal-gyrus",
361
+ "286": "Right-MOG---middle-occipital-gyrus",
362
+ "287": "Left-MOG---middle-occipital-gyrus",
363
+ "288": "Right-MOrG--medial-orbital-gyrus",
364
+ "289": "Left-MOrG--medial-orbital-gyrus",
365
+ "290": "Right-MPoG--postcentral-gyrus",
366
+ "291": "Left-MPoG--postcentral-gyrus",
367
+ "292": "Right-MPrG--precentral-gyrus",
368
+ "293": "Left-MPrG--precentral-gyrus",
369
+ "294": "Right-MSFG--superior-frontal-gyrus",
370
+ "295": "Left-MSFG--superior-frontal-gyrus",
371
+ "296": "Right-MTG---middle-temporal-gyrus",
372
+ "297": "Left-MTG---middle-temporal-gyrus",
373
+ "298": "Right-OCP---occipital-pole",
374
+ "299": "Left-OCP---occipital-pole",
375
+ "300": "Right-OFuG--occipital-fusiform-gyrus",
376
+ "301": "Left-OFuG--occipital-fusiform-gyrus",
377
+ "302": "Right-OpIFG-opercular-part-of-the-IFG",
378
+ "303": "Left-OpIFG-opercular-part-of-the-IFG",
379
+ "304": "Right-OrIFG-orbital-part-of-the-IFG",
380
+ "305": "Left-OrIFG-orbital-part-of-the-IFG",
381
+ "306": "Right-PCgG--posterior-cingulate-gyrus",
382
+ "307": "Left-PCgG--posterior-cingulate-gyrus",
383
+ "308": "Right-PCu---precuneus",
384
+ "309": "Left-PCu---precuneus",
385
+ "310": "Right-PHG---parahippocampal-gyrus",
386
+ "311": "Left-PHG---parahippocampal-gyrus",
387
+ "312": "Right-PIns--posterior-insula",
388
+ "313": "Left-PIns--posterior-insula",
389
+ "314": "Right-PO----parietal-operculum",
390
+ "315": "Left-PO----parietal-operculum",
391
+ "316": "Right-PoG---postcentral-gyrus",
392
+ "317": "Left-PoG---postcentral-gyrus",
393
+ "318": "Right-POrG--posterior-orbital-gyrus",
394
+ "319": "Left-POrG--posterior-orbital-gyrus",
395
+ "320": "Right-PP----planum-polare",
396
+ "321": "Left-PP----planum-polare",
397
+ "322": "Right-PrG---precentral-gyrus",
398
+ "323": "Left-PrG---precentral-gyrus",
399
+ "324": "Right-PT----planum-temporale",
400
+ "325": "Left-PT----planum-temporale",
401
+ "326": "Right-SCA---subcallosal-area",
402
+ "327": "Left-SCA---subcallosal-area",
403
+ "328": "Right-SFG---superior-frontal-gyrus",
404
+ "329": "Left-SFG---superior-frontal-gyrus",
405
+ "330": "Right-SMC---supplementary-motor-cortex",
406
+ "331": "Left-SMC---supplementary-motor-cortex",
407
+ "332": "Right-SMG---supramarginal-gyrus",
408
+ "333": "Left-SMG---supramarginal-gyrus",
409
+ "334": "Right-SOG---superior-occipital-gyrus",
410
+ "335": "Left-SOG---superior-occipital-gyrus",
411
+ "336": "Right-SPL---superior-parietal-lobule",
412
+ "337": "Left-SPL---superior-parietal-lobule",
413
+ "338": "Right-STG---superior-temporal-gyrus",
414
+ "339": "Left-STG---superior-temporal-gyrus",
415
+ "340": "Right-TMP---temporal-pole",
416
+ "341": "Left-TMP---temporal-pole",
417
+ "342": "Right-TrIFG-triangular-part-of-the-IFG",
418
+ "343": "Left-TrIFG-triangular-part-of-the-IFG",
419
+ "344": "Right-TTG---transverse-temporal-gyrus",
420
+ "345": "Left-TTG---transverse-temporal-gyrus"
421
+ }
422
+ }
423
+ },
424
+ "everything_labels": {
425
+ "CT_BODY": {
426
+ "0": "background",
427
+ "1": "liver",
428
+ "2": "kidney",
429
+ "3": "spleen",
430
+ "4": "pancreas",
431
+ "5": "right kidney",
432
+ "6": "aorta",
433
+ "7": "inferior vena cava",
434
+ "8": "right adrenal gland",
435
+ "9": "left adrenal gland",
436
+ "10": "gallbladder",
437
+ "11": "esophagus",
438
+ "12": "stomach",
439
+ "13": "duodenum",
440
+ "14": "left kidney",
441
+ "15": "bladder",
442
+ "16": "prostate or uterus",
443
+ "17": "portal vein and splenic vein",
444
+ "18": "rectum",
445
+ "19": "small bowel",
446
+ "20": "lung",
447
+ "21": "bone",
448
+ "22": "brain",
449
+ "23": "lung tumor",
450
+ "24": "pancreatic tumor",
451
+ "25": "hepatic vessel",
452
+ "26": "hepatic tumor",
453
+ "27": "colon cancer primaries",
454
+ "28": "left lung upper lobe",
455
+ "29": "left lung lower lobe",
456
+ "30": "right lung upper lobe",
457
+ "31": "right lung middle lobe",
458
+ "32": "right lung lower lobe",
459
+ "33": "vertebrae L5",
460
+ "34": "vertebrae L4",
461
+ "35": "vertebrae L3",
462
+ "36": "vertebrae L2",
463
+ "37": "vertebrae L1",
464
+ "38": "vertebrae T12",
465
+ "39": "vertebrae T11",
466
+ "40": "vertebrae T10",
467
+ "41": "vertebrae T9",
468
+ "42": "vertebrae T8",
469
+ "43": "vertebrae T7",
470
+ "44": "vertebrae T6",
471
+ "45": "vertebrae T5",
472
+ "46": "vertebrae T4",
473
+ "47": "vertebrae T3",
474
+ "48": "vertebrae T2",
475
+ "49": "vertebrae T1",
476
+ "50": "vertebrae C7",
477
+ "51": "vertebrae C6",
478
+ "52": "vertebrae C5",
479
+ "53": "vertebrae C4",
480
+ "54": "vertebrae C3",
481
+ "55": "vertebrae C2",
482
+ "56": "vertebrae C1",
483
+ "57": "trachea",
484
+ "58": "left iliac artery",
485
+ "59": "right iliac artery",
486
+ "60": "left iliac vena",
487
+ "61": "right iliac vena",
488
+ "62": "colon",
489
+ "63": "left rib 1",
490
+ "64": "left rib 2",
491
+ "65": "left rib 3",
492
+ "66": "left rib 4",
493
+ "67": "left rib 5",
494
+ "68": "left rib 6",
495
+ "69": "left rib 7",
496
+ "70": "left rib 8",
497
+ "71": "left rib 9",
498
+ "72": "left rib 10",
499
+ "73": "left rib 11",
500
+ "74": "left rib 12",
501
+ "75": "right rib 1",
502
+ "76": "right rib 2",
503
+ "77": "right rib 3",
504
+ "78": "right rib 4",
505
+ "79": "right rib 5",
506
+ "80": "right rib 6",
507
+ "81": "right rib 7",
508
+ "82": "right rib 8",
509
+ "83": "right rib 9",
510
+ "84": "right rib 10",
511
+ "85": "right rib 11",
512
+ "86": "right rib 12",
513
+ "87": "left humerus",
514
+ "88": "right humerus",
515
+ "89": "left scapula",
516
+ "90": "right scapula",
517
+ "91": "left clavicula",
518
+ "92": "right clavicula",
519
+ "93": "left femur",
520
+ "94": "right femur",
521
+ "95": "left hip",
522
+ "96": "right hip",
523
+ "97": "sacrum",
524
+ "98": "left gluteus maximus",
525
+ "99": "right gluteus maximus",
526
+ "100": "left gluteus medius",
527
+ "101": "right gluteus medius",
528
+ "102": "left gluteus minimus",
529
+ "103": "right gluteus minimus",
530
+ "104": "left autochthon",
531
+ "105": "right autochthon",
532
+ "106": "left iliopsoas",
533
+ "107": "right iliopsoas",
534
+ "108": "left atrial appendage",
535
+ "109": "brachiocephalic trunk",
536
+ "110": "left brachiocephalic vein",
537
+ "111": "right brachiocephalic vein",
538
+ "112": "left common carotid artery",
539
+ "113": "right common carotid artery",
540
+ "114": "costal cartilages",
541
+ "115": "heart",
542
+ "116": "left kidney cyst",
543
+ "117": "right kidney cyst",
544
+ "118": "prostate",
545
+ "119": "pulmonary vein",
546
+ "120": "skull",
547
+ "121": "spinal cord",
548
+ "122": "sternum",
549
+ "123": "left subclavian artery",
550
+ "124": "right subclavian artery",
551
+ "125": "superior vena cava",
552
+ "126": "thyroid gland",
553
+ "127": "vertebrae S1",
554
+ "128": "bone lesion",
555
+ "129": "kidney mass",
556
+ "130": "liver tumor",
557
+ "131": "vertebrae L6",
558
+ "132": "airway"
559
+ },
560
+ "MRI_BODY": {
561
+ "0": "background",
562
+ "9": "left adrenal gland",
563
+ "8": "right adrenal gland",
564
+ "6": "aorta",
565
+ "104": "left autochthon",
566
+ "105": "right autochthon",
567
+ "22": "brain",
568
+ "91": "left clavicula",
569
+ "92": "right clavicula",
570
+ "62": "colon",
571
+ "13": "duodenum",
572
+ "11": "esophagus",
573
+ "93": "left femur",
574
+ "94": "right femur",
575
+ "10": "gallbladder",
576
+ "98": "left gluteus maximus",
577
+ "99": "right gluteus maximus",
578
+ "100": "left gluteus medius",
579
+ "101": "right gluteus medius",
580
+ "102": "left gluteus minimus",
581
+ "103": "right gluteus minimus",
582
+ "115": "heart",
583
+ "95": "left hip",
584
+ "96": "right hip",
585
+ "87": "left humerus",
586
+ "88": "right humerus",
587
+ "58": "left iliac artery",
588
+ "59": "right iliac artery",
589
+ "60": "left iliac vena",
590
+ "61": "right iliac vena",
591
+ "106": "left iliopsoas",
592
+ "107": "right iliopsoas",
593
+ "7": "inferior vena cava",
594
+ "134": "intervertebral discs",
595
+ "14": "left kidney",
596
+ "5": "right kidney",
597
+ "1": "liver",
598
+ "135": "left lung",
599
+ "136": "right lung",
600
+ "4": "pancreas",
601
+ "17": "portal vein and splenic vein",
602
+ "118": "prostate",
603
+ "97": "sacrum",
604
+ "89": "left scapula",
605
+ "90": "right scapula",
606
+ "19": "small bowel",
607
+ "121": "spinal cord",
608
+ "3": "spleen",
609
+ "12": "stomach",
610
+ "15": "bladder",
611
+ "146": "vertebrae"
612
+ },
613
+ "MRI_BRAIN": {
614
+ "0": "background",
615
+ "214": "3rd-Ventricle",
616
+ "215": "4th-Ventricle",
617
+ "216": "Right-Accumbens-Area",
618
+ "217": "Left-Accumbens-Area",
619
+ "218": "Right-Amygdala",
620
+ "219": "Left-Amygdala",
621
+ "220": "Brain-Stem",
622
+ "221": "Right-Caudate",
623
+ "222": "Left-Caudate",
624
+ "223": "Right-Cerebellum-Exterior",
625
+ "224": "Left-Cerebellum-Exterior",
626
+ "225": "Right-Cerebellum-White-Matter",
627
+ "226": "Left-Cerebellum-White-Matter",
628
+ "227": "Right-Cerebral-White-Matter",
629
+ "228": "Left-Cerebral-White-Matter",
630
+ "229": "Right-Hippocampus",
631
+ "230": "Left-Hippocampus",
632
+ "231": "Right-Inf-Lat-Vent",
633
+ "232": "Left-Inf-Lat-Vent",
634
+ "233": "Right-Lateral-Ventricle",
635
+ "234": "Left-Lateral-Ventricle",
636
+ "235": "Right-Pallidum",
637
+ "236": "Left-Pallidum",
638
+ "237": "Right-Putamen",
639
+ "238": "Left-Putamen",
640
+ "239": "Right-Thalamus-Proper",
641
+ "240": "Left-Thalamus-Proper",
642
+ "241": "Right-Ventral-DC",
643
+ "242": "Left-Ventral-DC",
644
+ "243": "Cerebellar-Vermal-Lobules-I-V",
645
+ "244": "Cerebellar-Vermal-Lobules-VI-VII",
646
+ "245": "Cerebellar-Vermal-Lobules-VIII-X",
647
+ "246": "Left-Basal-Forebrain",
648
+ "247": "Right-Basal-Forebrain",
649
+ "248": "Right-ACgG--anterior-cingulate-gyrus",
650
+ "249": "Left-ACgG--anterior-cingulate-gyrus",
651
+ "250": "Right-AIns--anterior-insula",
652
+ "251": "Left-AIns--anterior-insula",
653
+ "252": "Right-AOrG--anterior-orbital-gyrus",
654
+ "253": "Left-AOrG--anterior-orbital-gyrus",
655
+ "254": "Right-AnG---angular-gyrus",
656
+ "255": "Left-AnG---angular-gyrus",
657
+ "256": "Right-Calc--calcarine-cortex",
658
+ "257": "Left-Calc--calcarine-cortex",
659
+ "258": "Right-CO----central-operculum",
660
+ "259": "Left-CO----central-operculum",
661
+ "260": "Right-Cun---cuneus",
662
+ "261": "Left-Cun---cuneus",
663
+ "262": "Right-Ent---entorhinal-area",
664
+ "263": "Left-Ent---entorhinal-area",
665
+ "264": "Right-FO----frontal-operculum",
666
+ "265": "Left-FO----frontal-operculum",
667
+ "266": "Right-FRP---frontal-pole",
668
+ "267": "Left-FRP---frontal-pole",
669
+ "268": "Right-FuG---fusiform-gyrus",
670
+ "269": "Left-FuG---fusiform-gyrus",
671
+ "270": "Right-GRe---gyrus-rectus",
672
+ "271": "Left-GRe---gyrus-rectus",
673
+ "272": "Right-IOG---inferior-occipital-gyrus",
674
+ "273": "Left-IOG---inferior-occipital-gyrus",
675
+ "274": "Right-ITG---inferior-temporal-gyrus",
676
+ "275": "Left-ITG---inferior-temporal-gyrus",
677
+ "276": "Right-LiG---lingual-gyrus",
678
+ "277": "Left-LiG---lingual-gyrus",
679
+ "278": "Right-LOrG--lateral-orbital-gyrus",
680
+ "279": "Left-LOrG--lateral-orbital-gyrus",
681
+ "280": "Right-MCgG--middle-cingulate-gyrus",
682
+ "281": "Left-MCgG--middle-cingulate-gyrus",
683
+ "282": "Right-MFC---medial-frontal-cortex",
684
+ "283": "Left-MFC---medial-frontal-cortex",
685
+ "284": "Right-MFG---middle-frontal-gyrus",
686
+ "285": "Left-MFG---middle-frontal-gyrus",
687
+ "286": "Right-MOG---middle-occipital-gyrus",
688
+ "287": "Left-MOG---middle-occipital-gyrus",
689
+ "288": "Right-MOrG--medial-orbital-gyrus",
690
+ "289": "Left-MOrG--medial-orbital-gyrus",
691
+ "290": "Right-MPoG--postcentral-gyrus",
692
+ "291": "Left-MPoG--postcentral-gyrus",
693
+ "292": "Right-MPrG--precentral-gyrus",
694
+ "293": "Left-MPrG--precentral-gyrus",
695
+ "294": "Right-MSFG--superior-frontal-gyrus",
696
+ "295": "Left-MSFG--superior-frontal-gyrus",
697
+ "296": "Right-MTG---middle-temporal-gyrus",
698
+ "297": "Left-MTG---middle-temporal-gyrus",
699
+ "298": "Right-OCP---occipital-pole",
700
+ "299": "Left-OCP---occipital-pole",
701
+ "300": "Right-OFuG--occipital-fusiform-gyrus",
702
+ "301": "Left-OFuG--occipital-fusiform-gyrus",
703
+ "302": "Right-OpIFG-opercular-part-of-the-IFG",
704
+ "303": "Left-OpIFG-opercular-part-of-the-IFG",
705
+ "304": "Right-OrIFG-orbital-part-of-the-IFG",
706
+ "305": "Left-OrIFG-orbital-part-of-the-IFG",
707
+ "306": "Right-PCgG--posterior-cingulate-gyrus",
708
+ "307": "Left-PCgG--posterior-cingulate-gyrus",
709
+ "308": "Right-PCu---precuneus",
710
+ "309": "Left-PCu---precuneus",
711
+ "310": "Right-PHG---parahippocampal-gyrus",
712
+ "311": "Left-PHG---parahippocampal-gyrus",
713
+ "312": "Right-PIns--posterior-insula",
714
+ "313": "Left-PIns--posterior-insula",
715
+ "314": "Right-PO----parietal-operculum",
716
+ "315": "Left-PO----parietal-operculum",
717
+ "316": "Right-PoG---postcentral-gyrus",
718
+ "317": "Left-PoG---postcentral-gyrus",
719
+ "318": "Right-POrG--posterior-orbital-gyrus",
720
+ "319": "Left-POrG--posterior-orbital-gyrus",
721
+ "320": "Right-PP----planum-polare",
722
+ "321": "Left-PP----planum-polare",
723
+ "322": "Right-PrG---precentral-gyrus",
724
+ "323": "Left-PrG---precentral-gyrus",
725
+ "324": "Right-PT----planum-temporale",
726
+ "325": "Left-PT----planum-temporale",
727
+ "326": "Right-SCA---subcallosal-area",
728
+ "327": "Left-SCA---subcallosal-area",
729
+ "328": "Right-SFG---superior-frontal-gyrus",
730
+ "329": "Left-SFG---superior-frontal-gyrus",
731
+ "330": "Right-SMC---supplementary-motor-cortex",
732
+ "331": "Left-SMC---supplementary-motor-cortex",
733
+ "332": "Right-SMG---supramarginal-gyrus",
734
+ "333": "Left-SMG---supramarginal-gyrus",
735
+ "334": "Right-SOG---superior-occipital-gyrus",
736
+ "335": "Left-SOG---superior-occipital-gyrus",
737
+ "336": "Right-SPL---superior-parietal-lobule",
738
+ "337": "Left-SPL---superior-parietal-lobule",
739
+ "338": "Right-STG---superior-temporal-gyrus",
740
+ "339": "Left-STG---superior-temporal-gyrus",
741
+ "340": "Right-TMP---temporal-pole",
742
+ "341": "Left-TMP---temporal-pole",
743
+ "342": "Right-TrIFG-triangular-part-of-the-IFG",
744
+ "343": "Left-TrIFG-triangular-part-of-the-IFG",
745
+ "344": "Right-TTG---transverse-temporal-gyrus",
746
+ "345": "Left-TTG---transverse-temporal-gyrus"
747
+ }
748
+ }
749
+ }
750
+ }
scripts/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ # from .evaluator import EnsembleEvaluator, Evaluator, SupervisedEvaluator
13
+ # from .multi_gpu_supervised_trainer import create_multigpu_supervised_evaluator, create_multigpu_supervised_trainer
14
+
15
+ from .early_stop_score_function import score_function
scripts/early_stop_score_function.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ import torch.distributed as dist
5
+
6
+
7
+ def score_function(engine):
8
+ val_metric = engine.state.metrics["val_mean_dice"]
9
+ if dist.is_initialized():
10
+ device = torch.device("cuda:" + os.environ["LOCAL_RANK"])
11
+ val_metric = torch.tensor([val_metric]).to(device)
12
+ dist.all_reduce(val_metric, op=dist.ReduceOp.SUM)
13
+ val_metric /= dist.get_world_size()
14
+ return val_metric.item()
15
+ return val_metric
scripts/evaluator.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
15
+
16
+ import numpy as np
17
+ import torch
18
+ from monai.engines.evaluator import SupervisedEvaluator
19
+ from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
20
+ from monai.inferers import Inferer, SimpleInferer
21
+ from monai.transforms import Transform, reset_ops_id
22
+ from monai.utils import ForwardMode, IgniteInfo, RankFilter, min_version, optional_import
23
+ from monai.utils.enums import CommonKeys as Keys
24
+ from torch.utils.data import DataLoader
25
+
26
+ rearrange, _ = optional_import("einops", name="rearrange")
27
+
28
+ if TYPE_CHECKING:
29
+ from ignite.engine import Engine, EventEnum
30
+ from ignite.metrics import Metric
31
+ else:
32
+ Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
33
+ Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
34
+ EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
35
+
36
+ __all__ = ["Vista3dEvaluator"]
37
+
38
+
39
+ class Vista3dEvaluator(SupervisedEvaluator):
40
+ """
41
+ Supervised detection evaluation method with image and label, inherits from ``SupervisedEvaluator`` and ``Workflow``.
42
+ Args:
43
+ device: an object representing the device on which to run.
44
+ val_data_loader: Ignite engine use data_loader to run, must be Iterable, typically be torch.DataLoader.
45
+ network: detector to evaluate in the evaluator, should be regular PyTorch `torch.nn.Module`.
46
+ epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
47
+ non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
48
+ with respect to the host. For other cases, this argument has no effect.
49
+ prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
50
+ from `engine.state.batch` for every iteration, for more details please refer to:
51
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
52
+ iteration_update: the callable function for every iteration, expect to accept `engine`
53
+ and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
54
+ if not provided, use `self._iteration()` instead. for more details please refer to:
55
+ https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
56
+ inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
57
+ postprocessing: execute additional transformation for the model output data.
58
+ Typically, several Tensor based transforms composed by `Compose`.
59
+ key_val_metric: compute metric when every iteration completed, and save average value to
60
+ engine.state.metrics when epoch completed. key_val_metric is the main metric to compare and save the
61
+ checkpoint into files.
62
+ additional_metrics: more Ignite metrics that also attach to Ignite Engine.
63
+ metric_cmp_fn: function to compare current key metric with previous best key metric value,
64
+ it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
65
+ `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
66
+ val_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
67
+ CheckpointHandler, StatsHandler, etc.
68
+ amp: whether to enable auto-mixed-precision evaluation, default is False.
69
+ mode: model forward mode during evaluation, should be 'eval' or 'train',
70
+ which maps to `model.eval()` or `model.train()`, default to 'eval'.
71
+ event_names: additional custom ignite events that will register to the engine.
72
+ new events can be a list of str or `ignite.engine.events.EventEnum`.
73
+ event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
74
+ for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
75
+ #ignite.engine.engine.Engine.register_events.
76
+ decollate: whether to decollate the batch-first data to a list of data after model computation,
77
+ recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
78
+ default to `True`.
79
+ to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
80
+ `device`, `non_blocking`.
81
+ amp_kwargs: dict of the args for `torch.amp.autocast()` API, for more details:
82
+ https://pytorch.org/docs/stable/amp.html#torch.amp.autocast.
83
+ """
84
+
85
+ def __init__(
86
+ self,
87
+ device: torch.device,
88
+ val_data_loader: Iterable | DataLoader,
89
+ network: torch.nn.Module,
90
+ epoch_length: int | None = None,
91
+ non_blocking: bool = False,
92
+ prepare_batch: Callable = default_prepare_batch,
93
+ iteration_update: Callable[[Engine, Any], Any] | None = None,
94
+ inferer: Inferer | None = None,
95
+ postprocessing: Transform | None = None,
96
+ key_val_metric: dict[str, Metric] | None = None,
97
+ additional_metrics: dict[str, Metric] | None = None,
98
+ metric_cmp_fn: Callable = default_metric_cmp_fn,
99
+ val_handlers: Sequence | None = None,
100
+ amp: bool = False,
101
+ mode: ForwardMode | str = ForwardMode.EVAL,
102
+ event_names: list[str | EventEnum | type[EventEnum]] | None = None,
103
+ event_to_attr: dict | None = None,
104
+ decollate: bool = True,
105
+ to_kwargs: dict | None = None,
106
+ amp_kwargs: dict | None = None,
107
+ hyper_kwargs: dict | None = None,
108
+ ) -> None:
109
+ super().__init__(
110
+ device=device,
111
+ val_data_loader=val_data_loader,
112
+ network=network,
113
+ epoch_length=epoch_length,
114
+ non_blocking=non_blocking,
115
+ prepare_batch=prepare_batch,
116
+ iteration_update=iteration_update,
117
+ postprocessing=postprocessing,
118
+ key_val_metric=key_val_metric,
119
+ additional_metrics=additional_metrics,
120
+ metric_cmp_fn=metric_cmp_fn,
121
+ val_handlers=val_handlers,
122
+ amp=amp,
123
+ mode=mode,
124
+ event_names=event_names,
125
+ event_to_attr=event_to_attr,
126
+ decollate=decollate,
127
+ to_kwargs=to_kwargs,
128
+ amp_kwargs=amp_kwargs,
129
+ )
130
+
131
+ self.network = network
132
+ self.device = device
133
+ self.inferer = SimpleInferer() if inferer is None else inferer
134
+ self.hyper_kwargs = hyper_kwargs
135
+ self.logger.addFilter(RankFilter())
136
+
137
+ def transform_points(self, point, affine):
138
+ """transform point to the coordinates of the transformed image
139
+ point: numpy array [bs, N, 3]
140
+ """
141
+ bs, n = point.shape[:2]
142
+ point = np.concatenate((point, np.ones((bs, n, 1))), axis=-1)
143
+ point = rearrange(point, "b n d -> d (b n)")
144
+ point = affine @ point
145
+ point = rearrange(point, "d (b n)-> b n d", b=bs)[:, :, :3]
146
+ return point
147
+
148
+ def check_prompts_format(self, label_prompt, points, point_labels):
149
+ """check the format of user prompts
150
+ label_prompt: [1,2,3,4,...,B] List of tensors
151
+ points: [[[x,y,z], [x,y,z], ...]] List of coordinates of a single object
152
+ point_labels: [[1,1,0,...]] List of scalar that matches number of points
153
+ """
154
+ # check prompt is given
155
+ if label_prompt is None and points is None:
156
+ everything_labels = self.hyper_kwargs.get("everything_labels", None)
157
+ if everything_labels is not None:
158
+ label_prompt = [torch.tensor(_) for _ in everything_labels]
159
+ return label_prompt, points, point_labels
160
+ else:
161
+ raise ValueError("Prompt must be given for inference.")
162
+ # check label_prompt
163
+ if label_prompt is not None:
164
+ if isinstance(label_prompt, list):
165
+ if not np.all([len(_) == 1 for _ in label_prompt]):
166
+ raise ValueError("Label prompt must be a list of single scalar, [1,2,3,4,...,].")
167
+ if not np.all([(x < 255).item() for x in label_prompt]):
168
+ raise ValueError("Current bundle only supports label prompt smaller than 255.")
169
+ if points is None:
170
+ supported_list = list({i + 1 for i in range(132)} - {16, 18, 129, 130, 131})
171
+ if not np.all([x in supported_list for x in label_prompt]):
172
+ raise ValueError("Undefined label prompt detected. Provide point prompts for zero-shot.")
173
+ else:
174
+ raise ValueError("Label prompt must be a list, [1,2,3,4,...,].")
175
+ # check points
176
+ if points is not None:
177
+ if point_labels is None:
178
+ raise ValueError("Point labels must be given if points are given.")
179
+ if not np.all([len(_) == 3 for _ in points]):
180
+ raise ValueError("Points must be three dimensional (x,y,z) in the shape of [[x,y,z],...,[x,y,z]].")
181
+ if len(points) != len(point_labels):
182
+ raise ValueError("Points must match point labels.")
183
+ if not np.all([_ in [-1, 0, 1, 2, 3] for _ in point_labels]):
184
+ raise ValueError("Point labels can only be -1,0,1 and 2,3 for special flags.")
185
+ if label_prompt is not None and points is not None:
186
+ if len(label_prompt) != 1:
187
+ raise ValueError("Label prompt can only be a single object if provided with point prompts.")
188
+ # check point_labels
189
+ if point_labels is not None:
190
+ if points is None:
191
+ raise ValueError("Points must be given if point labels are given.")
192
+ return label_prompt, points, point_labels
193
+
194
+ def _iteration(self, engine: SupervisedEvaluator, batchdata: dict[str, torch.Tensor]) -> dict:
195
+ """
196
+ callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
197
+ Return below items in a dictionary:
198
+ - IMAGE: image Tensor data for model input, already moved to device.
199
+ - LABEL: label Tensor data corresponding to the image, already moved to device.
200
+ - PRED: prediction result of model.
201
+
202
+ Args:
203
+ engine: `SupervisedEvaluator` to execute operation for an iteration.
204
+ batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
205
+
206
+ Raises:
207
+ ValueError: When ``batchdata`` is None.
208
+
209
+ """
210
+ if batchdata is None:
211
+ raise ValueError("Must provide batch data for current iteration.")
212
+ label_set = engine.hyper_kwargs.get("label_set", None)
213
+ # this validation label set should be consistent with 'labels.unique()', used to generate fg/bg points
214
+ val_label_set = engine.hyper_kwargs.get("val_label_set", label_set)
215
+ # If user provide prompts in the inference, input image must contain original affine.
216
+ # the point coordinates are from the original_affine space, while image here is after preprocess transforms.
217
+ if engine.hyper_kwargs["user_prompt"]:
218
+ inputs, label_prompt, points, point_labels = (
219
+ batchdata["image"],
220
+ batchdata.get("label_prompt", None),
221
+ batchdata.get("points", None),
222
+ batchdata.get("point_labels", None),
223
+ )
224
+ labels = None
225
+ label_prompt, points, point_labels = self.check_prompts_format(label_prompt, points, point_labels)
226
+ inputs = inputs.to(engine.device)
227
+ # For N foreground object, label_prompt is [1, N], but the batch number 1 needs to be removed. Convert to [N, 1]
228
+ label_prompt = (
229
+ torch.as_tensor([label_prompt]).to(inputs.device)[0].unsqueeze(-1) if label_prompt is not None else None
230
+ )
231
+ # For points, the size can only be [1, K, 3], where K is the number of points for this single foreground object.
232
+ if points is not None:
233
+ points = torch.as_tensor([points])
234
+ points = self.transform_points(
235
+ points, np.linalg.inv(inputs.affine[0]) @ inputs.meta["original_affine"][0].numpy()
236
+ )
237
+ points = torch.from_numpy(points).to(inputs.device)
238
+ point_labels = torch.as_tensor([point_labels]).to(inputs.device) if point_labels is not None else None
239
+
240
+ # If validation with ground truth label available.
241
+ else:
242
+ inputs, labels = engine.prepare_batch(
243
+ batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs
244
+ )
245
+ # create label prompt, this should be consistent with the label prompt used for training.
246
+ if label_set is None:
247
+ output_classes = engine.hyper_kwargs["output_classes"]
248
+ label_set = np.arange(output_classes).tolist()
249
+ label_prompt = torch.tensor(label_set).to(engine.state.device).unsqueeze(-1)
250
+ # point prompt is generated withing vista3d, provide empty points
251
+ points = torch.zeros(label_prompt.shape[0], 1, 3).to(inputs.device)
252
+ point_labels = -1 + torch.zeros(label_prompt.shape[0], 1).to(inputs.device)
253
+ # validation for either auto or point.
254
+ if engine.hyper_kwargs.get("val_head", "auto") == "auto":
255
+ # automatic only validation
256
+ # remove val_label_set, vista3d will not sample points from gt labels.
257
+ val_label_set = None
258
+ else:
259
+ # point only validation
260
+ label_prompt = None
261
+
262
+ # put iteration outputs into engine.state
263
+ engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: labels}
264
+ # execute forward computation
265
+ with engine.mode(engine.network):
266
+ if engine.amp:
267
+ with torch.amp.autocast("cuda", **engine.amp_kwargs):
268
+ engine.state.output[Keys.PRED] = engine.inferer(
269
+ inputs=inputs,
270
+ network=engine.network,
271
+ point_coords=points,
272
+ point_labels=point_labels,
273
+ class_vector=label_prompt,
274
+ labels=labels,
275
+ label_set=val_label_set,
276
+ )
277
+ else:
278
+ engine.state.output[Keys.PRED] = engine.inferer(
279
+ inputs=inputs,
280
+ network=engine.network,
281
+ point_coords=points,
282
+ point_labels=point_labels,
283
+ class_vector=label_prompt,
284
+ labels=labels,
285
+ label_set=val_label_set,
286
+ )
287
+ inputs = reset_ops_id(inputs)
288
+ # Add dim 0 for decollate batch
289
+ engine.state.output["label_prompt"] = label_prompt.unsqueeze(0) if label_prompt is not None else None
290
+ engine.state.output["points"] = points.unsqueeze(0) if points is not None else None
291
+ engine.state.output["point_labels"] = point_labels.unsqueeze(0) if point_labels is not None else None
292
+ engine.fire_event(IterationEvents.FORWARD_COMPLETED)
293
+ engine.fire_event(IterationEvents.MODEL_COMPLETED)
294
+ if torch.cuda.is_available():
295
+ torch.cuda.empty_cache()
296
+
297
+ return engine.state.output
scripts/inferer.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ import copy
13
+ from typing import List, Union
14
+
15
+ import torch
16
+ from monai.apps.vista3d.inferer import point_based_window_inferer
17
+ from monai.inferers import Inferer, SlidingWindowInfererAdapt
18
+ from torch import Tensor
19
+
20
+
21
+ class Vista3dInferer(Inferer):
22
+ """
23
+ Vista3D Inferer
24
+
25
+ Args:
26
+ roi_size: the sliding window patch size.
27
+ overlap: sliding window overlap ratio.
28
+ """
29
+
30
+ def __init__(self, roi_size, overlap, use_point_window=False, sw_batch_size=1) -> None:
31
+ Inferer.__init__(self)
32
+ self.roi_size = roi_size
33
+ self.overlap = overlap
34
+ self.sw_batch_size = sw_batch_size
35
+ self.use_point_window = use_point_window
36
+
37
+ def __call__(
38
+ self,
39
+ inputs: Union[List[Tensor], Tensor],
40
+ network,
41
+ point_coords,
42
+ point_labels,
43
+ class_vector,
44
+ labels=None,
45
+ label_set=None,
46
+ prev_mask=None,
47
+ ):
48
+ """
49
+ Unified callable function API of Inferers.
50
+ Notice: The point_based_window_inferer currently only supports SINGLE OBJECT INFERENCE with B=1.
51
+ It only used in interactive segmentation.
52
+
53
+ Args:
54
+ inputs: input tensor images.
55
+ network: vista3d model.
56
+ point_coords: point click coordinates. [B, N, 3].
57
+ point_labels: point click labels (0 for negative, 1 for positive) [B, N].
58
+ class_vector: class vector of length B.
59
+ labels: groundtruth labels. Used for sampling validation points.
60
+ label_set: [0,1,2,3,...,output_classes].
61
+ prev_mask: [1, B, H, W, D], THE VALUE IS BEFORE SIGMOID!
62
+
63
+ """
64
+ prompt_class = copy.deepcopy(class_vector)
65
+ if class_vector is not None and (point_labels is not None and torch.any(point_labels != -1)):
66
+ # Only when user perform zero-shot interactive during inference. Remove the class vector
67
+ # and keep the prompt_class to inform the model about the zero-shot. During finetuning,
68
+ # a novel class > last_supported is possible and should be taken care of.
69
+ # This check should be moved to evaluator and prompt_class should be added as input to the inferer.
70
+ if hasattr(network, "point_head"):
71
+ point_head = network.point_head
72
+ elif hasattr(network, "module") and hasattr(network.module, "point_head"):
73
+ point_head = network.module.point_head
74
+ else:
75
+ raise AttributeError("Network does not have attribute 'point_head'.")
76
+
77
+ if torch.any(class_vector > point_head.last_supported):
78
+ class_vector = None
79
+ val_outputs = None
80
+ torch.cuda.empty_cache()
81
+ if self.use_point_window and point_coords is not None:
82
+ if isinstance(inputs, list):
83
+ device = inputs[0].device
84
+ else:
85
+ device = inputs.device
86
+ val_outputs = point_based_window_inferer(
87
+ inputs=inputs,
88
+ roi_size=self.roi_size,
89
+ sw_batch_size=self.sw_batch_size,
90
+ transpose=True,
91
+ with_coord=True,
92
+ predictor=network,
93
+ mode="gaussian",
94
+ sw_device=device,
95
+ device=device,
96
+ overlap=self.overlap,
97
+ point_coords=point_coords,
98
+ point_labels=point_labels,
99
+ class_vector=class_vector,
100
+ prompt_class=prompt_class,
101
+ prev_mask=prev_mask,
102
+ labels=labels,
103
+ label_set=label_set,
104
+ )
105
+ else:
106
+ val_outputs = SlidingWindowInfererAdapt(
107
+ roi_size=self.roi_size, sw_batch_size=self.sw_batch_size, with_coord=True, padding_mode="replicate"
108
+ )(
109
+ inputs,
110
+ network,
111
+ transpose=True,
112
+ point_coords=point_coords,
113
+ point_labels=point_labels,
114
+ class_vector=class_vector,
115
+ prompt_class=prompt_class,
116
+ prev_mask=prev_mask,
117
+ labels=labels,
118
+ label_set=label_set,
119
+ )
120
+ return val_outputs
scripts/trainer.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from __future__ import annotations
13
+
14
+ from typing import TYPE_CHECKING, Any, Callable, Iterable, Sequence
15
+
16
+ import numpy as np
17
+ import torch
18
+ from monai.apps.vista3d.sampler import sample_prompt_pairs
19
+ from monai.engines.trainer import Trainer
20
+ from monai.engines.utils import IterationEvents, default_metric_cmp_fn, default_prepare_batch
21
+ from monai.inferers import Inferer, SimpleInferer
22
+ from monai.transforms import Transform
23
+ from monai.utils import IgniteInfo, RankFilter, min_version, optional_import
24
+ from monai.utils.enums import CommonKeys as Keys
25
+ from torch.optim.optimizer import Optimizer
26
+ from torch.utils.data import DataLoader
27
+
28
+ if TYPE_CHECKING:
29
+ from ignite.engine import Engine, EventEnum
30
+ from ignite.metrics import Metric
31
+ else:
32
+ Engine, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Engine")
33
+ Metric, _ = optional_import("ignite.metrics", IgniteInfo.OPT_IMPORT_VERSION, min_version, "Metric")
34
+ EventEnum, _ = optional_import("ignite.engine", IgniteInfo.OPT_IMPORT_VERSION, min_version, "EventEnum")
35
+
36
+ __all__ = ["Vista3dTrainer"]
37
+
38
+
39
+ class Vista3dTrainer(Trainer):
40
+ """
41
+ Supervised detection training method with image and label, inherits from ``Trainer`` and ``Workflow``.
42
+ Args:
43
+ device: an object representing the device on which to run.
44
+ max_epochs: the total epoch number for trainer to run.
45
+ train_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader.
46
+ detector: detector to train in the trainer, should be regular PyTorch `torch.nn.Module`.
47
+ optimizer: the optimizer associated to the detector, should be regular PyTorch optimizer from `torch.optim`
48
+ or its subclass.
49
+ epoch_length: number of iterations for one epoch, default to `len(train_data_loader)`.
50
+ non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
51
+ with respect to the host. For other cases, this argument has no effect.
52
+ prepare_batch: function to parse expected data (usually `image`,`box`, `label` and other detector args)
53
+ from `engine.state.batch` for every iteration, for more details please refer to:
54
+ https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
55
+ iteration_update: the callable function for every iteration, expect to accept `engine`
56
+ and `engine.state.batch` as inputs, return data will be stored in `engine.state.output`.
57
+ if not provided, use `self._iteration()` instead. for more details please refer to:
58
+ https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html.
59
+ inferer: inference method that execute model forward on input data, like: SlidingWindow, etc.
60
+ postprocessing: execute additional transformation for the model output data.
61
+ Typically, several Tensor based transforms composed by `Compose`.
62
+ key_train_metric: compute metric when every iteration completed, and save average value to
63
+ engine.state.metrics when epoch completlabel_set = np.arange(output_classes).tolist().
64
+ key_train_metric is the main metric to compare and save the checkpoint into files.
65
+ additional_metrics: more Ignite metrics that also attach to Ignite Engine.
66
+ metric_cmp_fn: function to compare current key metric with previous best key metric value,
67
+ it must accept 2 args (current_metric, previous_best) and return a bool result: if `True`, will update
68
+ `best_metric` and `best_metric_epoch` with current metric and epoch, default to `greater than`.
69
+ train_handlers: every handler is a set of Ignite Event-Handlers, must have `attach` function, like:
70
+ CheckpointHandler, StatsHandler, etc.
71
+ amp: whether to enable auto-mixed-precision training, default is False.
72
+ event_names: additional custom ignite events that will register to the engine.
73
+ new events can be a list of str or `ignite.engine.events.EventEnum`.
74
+ event_to_attr: a dictionary to map an event to a state attribute, then add to `engine.state`.
75
+ for more details, check: https://pytorch.org/ignite/generated/ignite.engine.engine.Engine.html
76
+ #ignite.engine.engine.Engine.register_events.
77
+ decollate: whether to decollate the batch-first data to a list of data after model computation,
78
+ recommend `decollate=True` when `postprocessing` uses components from `monai.transforms`.
79
+ default to `True`.
80
+ optim_set_to_none: when calling `optimizer.zero_grad()`, instead of setting to zero, set the grads to None.
81
+ more details: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html.
82
+ to_kwargs: dict of other args for `prepare_batch` API when converting the input data, except for
83
+ `device`, `non_blocking`.
84
+ amp_kwargs: dict of the args for `torch.amp.autocast()` API, for more details:
85
+ https://pytorch.org/docs/stable/amp.html#torch.amp.autocast.
86
+ """
87
+
88
+ def __init__(
89
+ self,
90
+ device: torch.device,
91
+ max_epochs: int,
92
+ train_data_loader: Iterable | DataLoader,
93
+ network: torch.nn.Module,
94
+ optimizer: Optimizer,
95
+ loss_function: Callable,
96
+ epoch_length: int | None = None,
97
+ non_blocking: bool = False,
98
+ prepare_batch: Callable = default_prepare_batch,
99
+ iteration_update: Callable[[Engine, Any], Any] | None = None,
100
+ inferer: Inferer | None = None,
101
+ postprocessing: Transform | None = None,
102
+ key_train_metric: dict[str, Metric] | None = None,
103
+ additional_metrics: dict[str, Metric] | None = None,
104
+ metric_cmp_fn: Callable = default_metric_cmp_fn,
105
+ train_handlers: Sequence | None = None,
106
+ amp: bool = False,
107
+ event_names: list[str | EventEnum] | None = None,
108
+ event_to_attr: dict | None = None,
109
+ decollate: bool = True,
110
+ optim_set_to_none: bool = False,
111
+ to_kwargs: dict | None = None,
112
+ amp_kwargs: dict | None = None,
113
+ hyper_kwargs: dict | None = None,
114
+ ) -> None:
115
+ super().__init__(
116
+ device=device,
117
+ max_epochs=max_epochs,
118
+ data_loader=train_data_loader,
119
+ epoch_length=epoch_length,
120
+ non_blocking=non_blocking,
121
+ prepare_batch=prepare_batch,
122
+ iteration_update=iteration_update,
123
+ postprocessing=postprocessing,
124
+ key_metric=key_train_metric,
125
+ additional_metrics=additional_metrics,
126
+ metric_cmp_fn=metric_cmp_fn,
127
+ handlers=train_handlers,
128
+ amp=amp,
129
+ event_names=event_names,
130
+ event_to_attr=event_to_attr,
131
+ decollate=decollate,
132
+ to_kwargs=to_kwargs,
133
+ amp_kwargs=amp_kwargs,
134
+ )
135
+
136
+ self.network = network
137
+ self.optimizer = optimizer
138
+ self.loss_function = loss_function
139
+ self.inferer = SimpleInferer() if inferer is None else inferer
140
+ self.optim_set_to_none = optim_set_to_none
141
+ self.hyper_kwargs = hyper_kwargs
142
+ self.logger.addFilter(RankFilter())
143
+
144
+ def _iteration(self, engine, batchdata: dict[str, torch.Tensor]):
145
+ """
146
+ Callback function for the Supervised Training processing logic of 1 iteration in Ignite Engine.
147
+ Return below items in a dictionary:
148
+ - IMAGE: image Tensor data for model input, already moved to device.
149
+ Args:
150
+ engine: `Vista3DTrainer` to execute operation for an iteration.
151
+ batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
152
+ Raises:
153
+ ValueError: When ``batchdata`` is None.
154
+ """
155
+
156
+ if batchdata is None:
157
+ raise ValueError("Must provide batch data for current iteration.")
158
+
159
+ inputs, labels = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
160
+ engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: labels}
161
+
162
+ label_set = engine.hyper_kwargs["label_set"]
163
+ output_classes = engine.hyper_kwargs["output_classes"]
164
+ if label_set is None:
165
+ label_set = np.arange(output_classes).tolist()
166
+ label_prompt, point, point_label, prompt_class = sample_prompt_pairs(
167
+ labels,
168
+ label_set,
169
+ image_size=engine.hyper_kwargs["patch_size"],
170
+ max_point=engine.hyper_kwargs["max_point"],
171
+ max_prompt=engine.hyper_kwargs["max_prompt"],
172
+ max_backprompt=engine.hyper_kwargs["max_backprompt"],
173
+ max_foreprompt=engine.hyper_kwargs["max_foreprompt"],
174
+ drop_label_prob=engine.hyper_kwargs["drop_label_prob"],
175
+ drop_point_prob=engine.hyper_kwargs["drop_point_prob"],
176
+ include_background=not engine.hyper_kwargs["exclude_background"],
177
+ )
178
+
179
+ def _compute_pred_loss():
180
+ outputs = engine.network(
181
+ input_images=inputs, point_coords=point, point_labels=point_label, class_vector=label_prompt
182
+ )
183
+ # engine.state.output[Keys.PRED] = outputs
184
+ engine.fire_event(IterationEvents.FORWARD_COMPLETED)
185
+ loss, loss_n = torch.tensor(0.0, device=engine.state.device), torch.tensor(0.0, device=engine.state.device)
186
+ for id in range(len(prompt_class)):
187
+ loss += engine.loss_function(outputs[[id]].float(), labels == prompt_class[id])
188
+ loss_n += 1.0
189
+ loss /= max(loss_n, 1.0)
190
+ engine.state.output[Keys.LOSS] = loss
191
+ outputs = None
192
+ torch.cuda.empty_cache()
193
+ engine.fire_event(IterationEvents.LOSS_COMPLETED)
194
+
195
+ engine.network.train()
196
+ engine.optimizer.zero_grad(set_to_none=engine.optim_set_to_none)
197
+
198
+ if engine.amp and engine.scaler is not None:
199
+ with torch.amp.autocast("cuda", **engine.amp_kwargs):
200
+ _compute_pred_loss()
201
+ engine.scaler.scale(engine.state.output[Keys.LOSS]).backward()
202
+ engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
203
+ engine.scaler.step(engine.optimizer)
204
+ engine.scaler.update()
205
+ else:
206
+ _compute_pred_loss()
207
+ engine.state.output[Keys.LOSS].backward()
208
+ engine.fire_event(IterationEvents.BACKWARD_COMPLETED)
209
+ engine.optimizer.step()
210
+ engine.fire_event(IterationEvents.MODEL_COMPLETED)
211
+ return engine.state.output
vista3d_config.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
+
4
+ class VISTA3DConfig(PretrainedConfig):
5
+ """Configuration class for vista3d"""
6
+
7
+ model_type = "VISTA3D"
8
+
9
+ def __init__(self, encoder_embed_dim: int = 48, input_channels: int = 1, **kwargs):
10
+ """
11
+ Set the hyperparameters for the VISTA3D model.
12
+
13
+ Parameters:
14
+ input_channels: channel of input images.
15
+ encoder_embed_dim: the encoder_embed_dim of the VISTA3D model.
16
+ """
17
+ self.input_channels = input_channels
18
+ self.encoder_embed_dim = encoder_embed_dim
19
+ super().__init__(**kwargs)
vista3d_model.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import monai.networks.nets
4
+ import torch
5
+ from transformers import AutoConfig, AutoModel, PreTrainedModel
6
+ from vista3d_config import VISTA3DConfig
7
+
8
+
9
+ class VISTA3DModel(PreTrainedModel):
10
+ """VISTA3D model for hugging face"""
11
+
12
+ config_class = VISTA3DConfig
13
+
14
+ def __init__(self, config):
15
+ super().__init__(config)
16
+ if config.model_type == "VISTA3D":
17
+ self.network = monai.networks.nets.vista3d132(
18
+ encoder_embed_dim=config.encoder_embed_dim,
19
+ in_channels=config.input_channels,
20
+ )
21
+
22
+ def forward(self, input):
23
+ return self.network(input)
24
+
25
+
26
+ def register_my_model():
27
+ """Utility function to register VISTA3D model so that it can be instantiate by the AutoModel function."""
28
+ AutoConfig.register("VISTA3D", VISTA3DConfig)
29
+ AutoModel.register(VISTA3DConfig, VISTA3DModel)
30
+
31
+
32
+ if __name__ == "__main__":
33
+ FILE_PATH = os.path.dirname(__file__)
34
+ MODEL_WEIGHT_PATH = os.path.join(FILE_PATH, "models/model.pt")
35
+ MODEL_PATH = os.path.join(FILE_PATH, "vista3d_pretrained_model")
36
+ config = VISTA3DConfig()
37
+ hugging_face_model = VISTA3DModel(config)
38
+ hugging_face_model.network.load_state_dict(torch.load(MODEL_WEIGHT_PATH))
39
+ hugging_face_model.save_pretrained(MODEL_PATH)
vista3d_pipeline.py ADDED
@@ -0,0 +1,462 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import logging
4
+ import os
5
+ import pathlib
6
+ from typing import Sequence
7
+
8
+ import numpy as np
9
+ import torch
10
+ from monai.apps.vista3d.transforms import VistaPostTransformd, VistaPreTransformd
11
+ from monai.data.utils import decollate_batch, list_data_collate
12
+ from monai.networks.utils import eval_mode, train_mode
13
+ from monai.transforms import (
14
+ CastToTyped,
15
+ Compose,
16
+ CropForegroundd,
17
+ EnsureChannelFirstd,
18
+ EnsureTyped,
19
+ Invertd,
20
+ Lambdad,
21
+ LoadImaged,
22
+ Orientationd,
23
+ SaveImaged,
24
+ ScaleIntensityRangePercentilesd,
25
+ Spacingd,
26
+ reset_ops_id,
27
+ )
28
+ from monai.utils import ForwardMode, optional_import, set_determinism
29
+ from monai.utils.enums import CommonKeys as Keys
30
+ from monai.utils.module import look_up_option
31
+ from scripts.inferer import Vista3dInferer
32
+ from transformers import AutoModel, Pipeline
33
+ from transformers.pipelines import PIPELINE_REGISTRY
34
+
35
+ rearrange, _ = optional_import("einops", name="rearrange")
36
+
37
+ FILE_PATH = os.path.dirname(__file__)
38
+
39
+
40
+ logging.basicConfig(
41
+ level=logging.INFO,
42
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
43
+ datefmt="%Y-%m-%d %H:%M:%S",
44
+ )
45
+ logger = logging.getLogger(__name__)
46
+
47
+
48
+ class VISTA3DPipeline(Pipeline):
49
+ """Define the VISTA3D pipeline."""
50
+
51
+ PREPROCESSING_EXTRA_ARGS = [
52
+ "image_key",
53
+ "resample_spacing",
54
+ "metadata_path",
55
+ "load_image",
56
+ ]
57
+ INFERENCE_EXTRA_ARGS = [
58
+ "mode",
59
+ "amp",
60
+ "hyper_kwargs",
61
+ "roi_size",
62
+ "overlap",
63
+ "sw_batch_size",
64
+ "use_point_window",
65
+ ]
66
+ POSTPROCESSING_EXTRA_ARGS = [
67
+ "pred_key",
68
+ "image_key",
69
+ "output_dir",
70
+ "output_ext",
71
+ "output_postfix",
72
+ "separate_folder",
73
+ "save_output",
74
+ ]
75
+ EVERYTHING_LABEL_CT = list(
76
+ set([i + 1 for i in range(132)])
77
+ - set([2, 16, 18, 20, 21, 23, 24, 25, 26, 27, 128, 129, 130, 131, 132])
78
+ )
79
+ EVERYTHING_LABEL_MRI = [1, 3, 4, 5, 6, 7, 8, 9, 10, 11,\
80
+ 12, 13, 14, 15, 17, 19, 22, 58, 59,\
81
+ 60, 61, 62, 87, 88, 89, 90, 91, 92,\
82
+ 93, 94, 95, 96, 97, 98, 99, 100, 101,\
83
+ 102, 103, 104, 105, 106, 107, 115, 118,\
84
+ 121, 134, 135, 136, 146]
85
+ EVERYTHING_LABEL_BRAIN = list(range(214,346))
86
+
87
+ def __init__(self, model, **kwargs):
88
+ super().__init__(model, **kwargs)
89
+ self.preprocessing_transforms = self._init_preprocessing_transforms(
90
+ **self._preprocess_params
91
+ )
92
+ self.inferer = self._init_inferer(**self._forward_params)
93
+ self.postprocessing_transforms = self._init_postprocessing_transforms(
94
+ **self._postprocess_params
95
+ )
96
+
97
+ def _init_inferer(
98
+ self,
99
+ roi_size: Sequence = (192, 192, 128),
100
+ overlap: float = 0.3,
101
+ sw_batch_size: int = 1,
102
+ use_point_window: bool = True,
103
+ ):
104
+ return Vista3dInferer(
105
+ roi_size=roi_size,
106
+ overlap=overlap,
107
+ use_point_window=use_point_window,
108
+ sw_batch_size=sw_batch_size,
109
+ )
110
+
111
+ def _init_preprocessing_transforms(
112
+ self,
113
+ image_key: str = "image",
114
+ resample_spacing: Sequence = (1.5, 1.5, 1.5),
115
+ metadata_path: str = os.path.join(FILE_PATH, "metadata.json"),
116
+ load_image: bool = True,
117
+ ):
118
+ device = self.device
119
+ subclass = {
120
+ "2": [14, 5],
121
+ "20": [28, 29, 30, 31, 32],
122
+ "21": list(range(33, 57)) + list(range(63, 98)) + [114, 120, 122],
123
+ }
124
+ metadata = json.loads(pathlib.Path(metadata_path).read_text())
125
+ labels_dict = metadata["network_data_format"]["outputs"]["pred"]["channel_def"]
126
+ preprocessing_list = [
127
+ LoadImaged(keys=image_key, image_only=True),
128
+ EnsureChannelFirstd(keys=image_key),
129
+ EnsureTyped(keys=image_key, device=device, track_meta=True),
130
+ Spacingd(keys=image_key, pixdim=resample_spacing, mode="bilinear"),
131
+ CropForegroundd(
132
+ keys=image_key, allow_smaller=True, margin=10, source_key=image_key
133
+ ),
134
+ VistaPreTransformd(
135
+ keys=image_key, subclass=subclass, labels_dict=labels_dict
136
+ ),
137
+ ScaleIntensityRangePercentilesd(
138
+ keys=image_key,
139
+ lower=1,
140
+ upper=99,
141
+ b_min=0,
142
+ b_max=1,
143
+ clip=True,
144
+ ),
145
+ Orientationd(keys=image_key, axcodes="RAS"),
146
+ CastToTyped(keys=image_key, dtype=torch.float32),
147
+ ]
148
+ if not load_image:
149
+ preprocessing_list.pop(0)
150
+
151
+ preprocessing_transforms = Compose(preprocessing_list)
152
+ return preprocessing_transforms
153
+
154
+ def _init_postprocessing_transforms(
155
+ self,
156
+ pred_key: str = "pred",
157
+ image_key: str = "image",
158
+ output_dir: str = "output_directory",
159
+ output_ext: str = ".nii.gz",
160
+ output_dtype: torch.dtype = torch.float32,
161
+ output_postfix: str = "seg",
162
+ separate_folder: bool = True,
163
+ save_output: bool = True,
164
+ ):
165
+ transforms = [
166
+ VistaPostTransformd(keys=pred_key),
167
+ Invertd(
168
+ keys=pred_key,
169
+ transform=copy.deepcopy(self.preprocessing_transforms),
170
+ orig_keys=image_key,
171
+ nearest_interp=True,
172
+ to_tensor=True,
173
+ ),
174
+ Lambdad(keys=pred_key, func=lambda x: torch.nan_to_num(x, nan=255)),
175
+ ]
176
+ if save_output:
177
+ transforms.append(
178
+ SaveImaged(
179
+ keys=pred_key,
180
+ resample=False,
181
+ output_dir=output_dir,
182
+ output_ext=output_ext,
183
+ output_dtype=output_dtype,
184
+ output_postfix=output_postfix,
185
+ separate_folder=separate_folder,
186
+ ),
187
+ )
188
+ postprocessing_transforms = Compose(transforms=transforms)
189
+ return postprocessing_transforms
190
+
191
+ def _sanitize_parameters(self, **kwargs):
192
+ """
193
+ _sanitize_parameters exists to allow users to pass any parameters whenever they wish,
194
+ be it at initialization time pipeline(...., maybe_arg=4) or at call time pipe = pipeline(...); output = pipe(...., maybe_arg=4).
195
+ The returns of _sanitize_parameters are the 3 dicts of kwargs that will be passed directly to preprocess, _forward and postprocess.
196
+ Don't fill anything if the caller didn't call with any extra parameter. That allows to keep the default arguments in the function
197
+ definition which is always more “natural”."""
198
+
199
+ vista3d_preprocessing_kwargs = {}
200
+ vista3d_infer_kwargs = {}
201
+ vista3d_postprocessing_kwargs = {}
202
+ for key in self.INFERENCE_EXTRA_ARGS:
203
+ if key in kwargs:
204
+ vista3d_infer_kwargs[key] = kwargs[key]
205
+
206
+ for key in self.PREPROCESSING_EXTRA_ARGS:
207
+ if key in kwargs:
208
+ vista3d_preprocessing_kwargs[key] = kwargs[key]
209
+
210
+ for key in self.POSTPROCESSING_EXTRA_ARGS:
211
+ if key in kwargs:
212
+ vista3d_postprocessing_kwargs[key] = kwargs[key]
213
+
214
+ return (
215
+ vista3d_preprocessing_kwargs,
216
+ vista3d_infer_kwargs,
217
+ vista3d_postprocessing_kwargs,
218
+ )
219
+
220
+ def check_prompts_format(self, label_prompt, points, point_labels):
221
+ """check the format of user prompts
222
+ label_prompt: [1,2,3,4,...,B] List of tensors
223
+ points: [[[x,y,z], [x,y,z], ...]] List of coordinates of a single object
224
+ point_labels: [[1,1,0,...]] List of scalar that matches number of points
225
+ """
226
+ # check prompt is given
227
+ if label_prompt is None and points is None:
228
+ everything_labels = self.hyper_kwargs.get("everything_labels", None)
229
+ if everything_labels is not None:
230
+ label_prompt = [torch.tensor(_) for _ in everything_labels]
231
+ return label_prompt, points, point_labels
232
+ else:
233
+ raise ValueError("Prompt must be given for inference.")
234
+ # check label_prompt
235
+ if label_prompt is not None:
236
+ if isinstance(label_prompt, list):
237
+ if not np.all([len(_) == 1 for _ in label_prompt]):
238
+ raise ValueError("Label prompt must be a list of single scalar, [1,2,3,4,...,].")
239
+ if not np.all([(x < 512).item() for x in label_prompt]):
240
+ raise ValueError("Current bundle only supports label prompt smaller than 512.")
241
+ if points is None:
242
+ supported_list = list({i + 1 for i in range(345)} - {16, 129, 130, 131, 133, 137, 138, 139, 140, 141, 142, 143, 144, 145, 162})
243
+ if not np.all([x in supported_list for x in label_prompt]):
244
+ raise ValueError("Undefined label prompt detected. Provide point prompts for zero-shot.")
245
+ else:
246
+ raise ValueError("Label prompt must be a list, [1,2,3,4,...,].")
247
+ # check points
248
+ if points is not None:
249
+ if point_labels is None:
250
+ raise ValueError("Point labels must be given if points are given.")
251
+ if not np.all([len(_) == 3 for _ in points]):
252
+ raise ValueError("Points must be three dimensional (x,y,z) in the shape of [[x,y,z],...,[x,y,z]].")
253
+ if len(points) != len(point_labels):
254
+ raise ValueError("Points must match point labels.")
255
+ if not np.all([_ in [-1, 0, 1, 2, 3] for _ in point_labels]):
256
+ raise ValueError("Point labels can only be -1,0,1 and 2,3 for special flags.")
257
+ if label_prompt is not None and points is not None:
258
+ if len(label_prompt) != 1:
259
+ raise ValueError("Label prompt can only be a single object if provided with point prompts.")
260
+ # check point_labels
261
+ if point_labels is not None:
262
+ if points is None:
263
+ raise ValueError("Points must be given if point labels are given.")
264
+ return label_prompt, points, point_labels
265
+
266
+ def transform_points(self, point, affine):
267
+ """transform point to the coordinates of the transformed image
268
+ point: numpy array [bs, N, 3]
269
+ """
270
+ bs, n = point.shape[:2]
271
+ point = np.concatenate((point, np.ones((bs, n, 1))), axis=-1)
272
+ point = rearrange(point, "b n d -> d (b n)")
273
+ point = affine @ point
274
+ point = rearrange(point, "d (b n)-> b n d", b=bs)[:, :, :3]
275
+ return point
276
+
277
+ def preprocess(
278
+ self,
279
+ inputs,
280
+ **kwargs,
281
+ ):
282
+ for key, value in kwargs.items():
283
+ if key in self._preprocess_params and value != self._preprocess_params[key]:
284
+ logging.warning(
285
+ f"Please set the parameter {key} during initialization."
286
+ )
287
+
288
+ if key not in self.PREPROCESSING_EXTRA_ARGS:
289
+ logging.warning(f"Cannot set parameter {key} for preprocessing.")
290
+
291
+ # Handle modality in input if provided
292
+ if isinstance(inputs, dict) and "modality" in inputs:
293
+ modality = look_up_option(inputs["modality"], ["CT_BODY", "MRI_BODY", "MRI_BRAIN"])
294
+ inputs["label_prompt"] = self.get_everything_labels(modality)
295
+ del inputs["modality"] # Remove modality key as it's not needed for transforms
296
+
297
+ inputs = self.preprocessing_transforms(inputs)
298
+ inputs = list_data_collate([inputs])
299
+ return inputs
300
+
301
+ def get_everything_labels(self, modality: str = 'CT_BODY'):
302
+ """Get the label set for automatic segmentation based on modality."""
303
+ if modality == "CT_BODY":
304
+ return self.EVERYTHING_LABEL_CT
305
+ elif modality == "MRI_BODY":
306
+ return self.EVERYTHING_LABEL_MRI
307
+ elif modality == "MRI_BRAIN":
308
+ return self.EVERYTHING_LABEL_BRAIN
309
+ else:
310
+ raise ValueError(f"Unsupported modality: {modality}")
311
+
312
+ def _forward(
313
+ self,
314
+ inputs,
315
+ mode: str = ForwardMode.EVAL,
316
+ amp: bool = True,
317
+ hyper_kwargs: dict = {"user_prompt": 1, "everything_labels": 1},
318
+ ):
319
+ set_determinism(seed=123)
320
+
321
+ if inputs is None:
322
+ raise ValueError("Must provide input data for inference.")
323
+
324
+ # Update everything_labels based on modality if not provided]
325
+ if "everything_labels" not in hyper_kwargs:
326
+ hyper_kwargs["everything_labels"] = self.get_everything_labels()
327
+ self.hyper_kwargs = hyper_kwargs
328
+
329
+ label_set = hyper_kwargs.get("label_set", None)
330
+ # this validation label set should be consistent with 'labels.unique()', used to generate fg/bg points
331
+ val_label_set = hyper_kwargs.get("val_label_set", label_set)
332
+ # If user provide prompts in the inference, input image must contain original affine.
333
+ # the point coordinates are from the original_affine space, while image here is after preprocess transforms.
334
+ if hyper_kwargs["user_prompt"]:
335
+ inputs, label_prompt, points, point_labels = (
336
+ inputs["image"],
337
+ inputs.get("label_prompt", None),
338
+ inputs.get("points", None),
339
+ inputs.get("point_labels", None),
340
+ )
341
+ labels = None
342
+ label_prompt, points, point_labels = self.check_prompts_format(
343
+ label_prompt, points, point_labels
344
+ )
345
+ inputs = inputs.to(self.device)
346
+ # For N foreground object, label_prompt is [1, N], but the batch number 1 needs to be removed. Convert to [N, 1]
347
+ label_prompt = (
348
+ torch.as_tensor([label_prompt]).to(inputs.device)[0].unsqueeze(-1)
349
+ if label_prompt is not None
350
+ else None
351
+ )
352
+ # For points, the size can only be [1, K, 3], where K is the number of points for this single foreground object.
353
+ if points is not None:
354
+ points = torch.as_tensor([points])
355
+ points = self.transform_points(
356
+ points,
357
+ np.linalg.inv(inputs.affine[0])
358
+ @ inputs.meta["original_affine"][0].numpy(),
359
+ )
360
+ points = torch.from_numpy(points).to(inputs.device)
361
+ point_labels = (
362
+ torch.as_tensor([point_labels]).to(inputs.device)
363
+ if point_labels is not None
364
+ else None
365
+ )
366
+
367
+ # If validation with ground truth label available.
368
+ else:
369
+ # TODO add these as attribute.
370
+ inputs, labels = inputs["image"], inputs["label"]
371
+ # create label prompt, this should be consistent with the label prompt used for training.
372
+ if label_set is None:
373
+ output_classes = hyper_kwargs.get("output_classes", None)
374
+ label_set = np.arange(output_classes).tolist()
375
+ label_prompt = torch.tensor(label_set).to(self.device).unsqueeze(-1)
376
+ # point prompt is generated withing vista3d, provide empty points
377
+ points = torch.zeros(label_prompt.shape[0], 1, 3).to(inputs.device)
378
+ point_labels = -1 + torch.zeros(label_prompt.shape[0], 1).to(inputs.device)
379
+ # validation for either auto or point.
380
+ if hyper_kwargs.get("val_head", "auto") == "auto":
381
+ # automatic only validation
382
+ # remove val_label_set, vista3d will not sample points from gt labels.
383
+ val_label_set = None
384
+ else:
385
+ # point only validation
386
+ label_prompt = None
387
+
388
+ # put iteration outputs into outputs TODO need to align with the customized inputs
389
+ outputs = {Keys.IMAGE: inputs, Keys.LABEL: labels}
390
+ mode = look_up_option(mode, ForwardMode)
391
+ if mode == ForwardMode.EVAL:
392
+ mode = eval_mode
393
+ elif mode == ForwardMode.TRAIN:
394
+ mode = train_mode
395
+ else:
396
+ raise ValueError(f"unsupported mode: {mode}, should be 'eval' or 'train'.")
397
+
398
+ # execute forward computation
399
+ self.model.network.to(self.device)
400
+ with mode(self.model):
401
+ if amp:
402
+ with torch.autocast("cuda"):
403
+ outputs[Keys.PRED] = self.inferer(
404
+ inputs=inputs,
405
+ network=self.model.network,
406
+ point_coords=points,
407
+ point_labels=point_labels,
408
+ class_vector=label_prompt,
409
+ labels=labels,
410
+ label_set=val_label_set,
411
+ )
412
+ else:
413
+ outputs[Keys.PRED] = self.inferer(
414
+ inputs=inputs,
415
+ network=self.model.network,
416
+ point_coords=points,
417
+ point_labels=point_labels,
418
+ class_vector=label_prompt,
419
+ labels=labels,
420
+ label_set=val_label_set,
421
+ )
422
+ inputs = reset_ops_id(inputs)
423
+ # Add dim 0 for decollate batch
424
+ outputs["label_prompt"] = (
425
+ label_prompt.unsqueeze(0) if label_prompt is not None else None
426
+ )
427
+ outputs["points"] = points.unsqueeze(0) if points is not None else None
428
+ outputs["point_labels"] = (
429
+ point_labels.unsqueeze(0) if point_labels is not None else None
430
+ )
431
+ if torch.cuda.is_available():
432
+ torch.cuda.empty_cache()
433
+
434
+ return outputs
435
+
436
+ def postprocess(self, outputs, **kwargs):
437
+ outputs[Keys.IMAGE] = outputs[Keys.IMAGE].to(self.device)
438
+ outputs[Keys.PRED] = outputs[Keys.PRED].to(self.device)
439
+ for key, value in kwargs.items():
440
+ if key not in self.POSTPROCESSING_EXTRA_ARGS:
441
+ logging.warning(f"Cannot set parameter {key} for postprocessing.")
442
+ if (
443
+ key in self._postprocess_params
444
+ and value != self._postprocess_params[key]
445
+ ) or (key not in self._postprocess_params):
446
+ self._postprocess_params.update(kwargs)
447
+ self.postprocessing_transforms = self._init_postprocessing_transforms(
448
+ **self._postprocess_params
449
+ )
450
+
451
+ outputs = self.postprocessing_transforms(decollate_batch(outputs))
452
+ return outputs
453
+
454
+
455
+ def register_simple_pipeline():
456
+ PIPELINE_REGISTRY.register_pipeline(
457
+ "vista3d",
458
+ pipeline_class=VISTA3DPipeline,
459
+ pt_model=AutoModel,
460
+ default={"pt": (os.path.join(FILE_PATH, "vista3d_pretrained_model"), "")},
461
+ type="image", # current support type: text, audio, image, multimodal
462
+ )
vista3d_pretrained_model/config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "VISTA3DModel"
4
+ ],
5
+ "encoder_embed_dim": 48,
6
+ "input_channels": 1,
7
+ "model_type": "VISTA3D",
8
+ "torch_dtype": "float32",
9
+ "transformers_version": "4.46.3"
10
+ }
vista3d_pretrained_model/model.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a02845891fd747f3144ebc33c7d0dc4c79dbd0333392b8b1b50f2c99e6f0ed67
3
+ size 871971235
vista3d_pretrained_model/model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65a0be47fcc84a41e46b457d1900ae584d2e5152f7d2c99ab48dc2ef0cc826c1
3
+ size 871894080