Image Segmentation

Improve model card: Add pipeline tag, library name, and link to Github repo

#1
by nielsr HF Staff - opened
Files changed (1) hide show
  1. README.md +203 -1
README.md CHANGED
@@ -1,3 +1,205 @@
1
  ---
2
  license: cc-by-nc-sa-4.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  license: cc-by-nc-sa-4.0
3
+ library_name: pytorch
4
+ pipeline_tag: image-segmentation
5
+ ---
6
+
7
+ <img src="imgs/nnInteractive_header_white.png">
8
+
9
+ # Python backend for `nnInteractive: Redefining 3D Promptable Segmentation`
10
+
11
+ This repository contains the nnInteractive python backend for our
12
+ [napari plugin](https://github.com/MIC-DKFZ/napari-nninteractive) and [MITK integration](Todo). It can be used for
13
+ python-based inference.
14
+
15
+ ## What is nnInteractive?
16
+
17
+ > Isensee, F.\*, Rokuss, M.\*, Krämer, L.\*, Dinkelacker, S., Ravindran, A., Stritzke, F., Hamm, B., Wald, T., Langenberg, M., Ulrich, C., Deissler, J., Floca, R., & Maier-Hein, K. (2025). nnInteractive: Redefining 3D Promptable Segmentation. https://arxiv.org/abs/2503.08373 \
18
+ > *: equal contribution
19
+
20
+ Link: [![arXiv](https://img.shields.io/badge/arXiv-2503.08373-b31b1b.svg)](https://arxiv.org/abs/2503.08373)
21
+
22
+ ##### Abstract:
23
+
24
+ Accurate and efficient 3D segmentation is essential for both clinical and research applications. While foundation
25
+ models like SAM have revolutionized interactive segmentation, their 2D design and domain shift limitations make them
26
+ ill-suited for 3D medical images. Current adaptations address some of these challenges but remain limited, either
27
+ lacking volumetric awareness, offering restricted interactivity, or supporting only a small set of structures and
28
+ modalities. Usability also remains a challenge, as current tools are rarely integrated into established imaging
29
+ platforms and often rely on cumbersome web-based interfaces with restricted functionality. We introduce nnInteractive,
30
+ the first comprehensive 3D interactive open-set segmentation method. It supports diverse prompts—including points,
31
+ scribbles, boxes, and a novel lasso prompt—while leveraging intuitive 2D interactions to generate full 3D
32
+ segmentations. Trained on 120+ diverse volumetric 3D datasets (CT, MRI, PET, 3D Microscopy, etc.), nnInteractive
33
+ sets a new state-of-the-art in accuracy, adaptability, and usability. Crucially, it is the first method integrated
34
+ into widely used image viewers (e.g., Napari, MITK), ensuring broad accessibility for real-world clinical and research
35
+ applications. Extensive benchmarking demonstrates that nnInteractive far surpasses existing methods, setting a new
36
+ standard for AI-driven interactive 3D segmentation.
37
+
38
+ <img src="imgs/figure1_method.png" width="1200">
39
+
40
+ ## Installation
41
+
42
+ ### Prerequisites
43
+
44
+ You need a Linux or Windows computer with an Nvidia GPU. 10GB of VRAM is recommended. Small objects should work with \<6GB.
45
+
46
+ ##### 1. Create a virtual environment:
47
+
48
+ nnInteractive supports Python 3.10+ and works with Conda, pip, or any other virtual environment. Here’s an example using Conda:
49
+
50
+ ```
51
+ conda create -n nnInteractive python=3.12
52
+ conda activate nnInteractive
53
+ ```
54
+
55
+ ##### 2. Install the correct PyTorch for your system
56
+
57
+ Go to the [PyTorch homepage](https://pytorch.org/get-started/locally/) and pick the right configuration.
58
+ Note that since recently PyTorch needs to be installed via pip. This is fine to do within your conda environment.
59
+
60
+ For Ubuntu with a Nvidia GPU, pick 'stable', 'Linux', 'Pip', 'Python', 'CUDA12.6' (if all drivers are up to date, otherwise use and older version):
61
+
62
+ ```
63
+ pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126
64
+ ```
65
+
66
+ ##### 3. Install this repository
67
+ Either install via pip:
68
+ `pip install nninteractive`
69
+
70
+ Or clone and install this repository:
71
+ ```bash
72
+ git clone https://github.com/MIC-DKFZ/nnInteractive
73
+ cd nnInteractive
74
+ pip install -e .
75
+ ```
76
+
77
+ ## Getting Started
78
+ Here is a minimalistic script that covers the core functionality of nnInteractive:
79
+
80
+ ```python
81
+ import os
82
+ import torch
83
+ import SimpleITK as sitk
84
+ from huggingface_hub import snapshot_download # Install huggingface_hub if not already installed
85
+
86
+ # --- Download Trained Model Weights (~400MB) ---
87
+ REPO_ID = "nnInteractive/nnInteractive"
88
+ MODEL_NAME = "nnInteractive_v1.0" # Updated models may be available in the future
89
+ DOWNLOAD_DIR = "/home/isensee/temp" # Specify the download directory
90
+
91
+ download_path = snapshot_download(
92
+ repo_id=REPO_ID,
93
+ allow_patterns=[f"{MODEL_NAME}/*"],
94
+ local_dir=DOWNLOAD_DIR
95
+ )
96
+
97
+ # The model is now stored in DOWNLOAD_DIR/MODEL_NAME.
98
+
99
+ # --- Initialize Inference Session ---
100
+ from nnInteractive.inference.inference_session import nnInteractiveInferenceSession
101
+
102
+ session = nnInteractiveInferenceSession(
103
+ device=torch.device("cuda:0"), # Set inference device
104
+ use_torch_compile=False, # Experimental: Not tested yet
105
+ verbose=False,
106
+ torch_n_threads=os.cpu_count(), # Use available CPU cores
107
+ do_autozoom=True, # Enables AutoZoom for better patching
108
+ use_pinned_memory=True, # Optimizes GPU memory transfers
109
+ )
110
+
111
+ # Load the trained model
112
+ model_path = os.path.join(DOWNLOAD_DIR, MODEL_NAME)
113
+ session.initialize_from_trained_model_folder(model_path)
114
+
115
+ # --- Load Input Image (Example with SimpleITK) ---
116
+ input_image = sitk.ReadImage("FILENAME")
117
+ img = sitk.GetArrayFromImage(input_image)[None] # Ensure shape (1, x, y, z)
118
+
119
+ # Validate input dimensions
120
+ if img.ndim != 4:
121
+ raise ValueError("Input image must be 4D with shape (1, x, y, z)")
122
+
123
+ session.set_image(img)
124
+
125
+ # --- Define Output Buffer ---
126
+ target_tensor = torch.zeros(img.shape[1:], dtype=torch.uint8) # Must be 3D (x, y, z)
127
+ session.set_target_buffer(target_tensor)
128
+
129
+ # --- Interacting with the Model ---
130
+ # Interactions can be freely chained and mixed in any order. Each interaction refines the segmentation.
131
+ # The model updates the segmentation mask in the target buffer after every interaction.
132
+
133
+ # Example: Add a point interaction
134
+ # POINT_COORDINATES should be a tuple (x, y, z) specifying the point location.
135
+ session.add_point_interaction(POINT_COORDINATES, include_interaction=True)
136
+
137
+ # Example: Add a bounding box interaction
138
+ # BBOX_COORDINATES must be specified as [[x1, x2], [y1, y2], [z1, z2]] (half-open intervals).
139
+ # Note: nnInteractive pre-trained models currently only support **2D bounding boxes**.
140
+ # This means that **one dimension must be [d, d+1]** to indicate a single slice.
141
+
142
+ # Example of a 2D bounding box in the axial plane (XY slice at depth Z)
143
+ # BBOX_COORDINATES = [[30, 80], [40, 100], [10, 11]] # X: 30-80, Y: 40-100, Z: slice 10
144
+
145
+ session.add_bbox_interaction(BBOX_COORDINATES, include_interaction=True)
146
+
147
+ # Example: Add a scribble interaction
148
+ # - A 3D image of the same shape as img where one slice (any axis-aligned orientation) contains a hand-drawn scribble.
149
+ # - Background must be 0, and scribble must be 1.
150
+ # - Use session.preferred_scribble_thickness for optimal results.
151
+ session.add_scribble_interaction(SCRIBBLE_IMAGE, include_interaction=True)
152
+
153
+ # Example: Add a lasso interaction
154
+ # - Similarly to scribble a 3D image with a single slice containing a **closed contour** representing the selection.
155
+ session.add_lasso_interaction(LASSO_IMAGE, include_interaction=True)
156
+
157
+ # You can combine any number of interactions as needed.
158
+ # The model refines the segmentation result incrementally with each new interaction.
159
+
160
+ # --- Retrieve Results ---
161
+ # The target buffer holds the segmentation result.
162
+ results = session.target_buffer.clone()
163
+ # OR (equivalent)
164
+ results = target_tensor.clone()
165
+
166
+ # Cloning is required because the buffer will be **reused** for the next object.
167
+ # Alternatively, set a new target buffer for each object:
168
+ session.set_target_buffer(torch.zeros(img.shape[1:], dtype=torch.uint8))
169
+
170
+ # --- Start a New Object Segmentation ---
171
+ session.reset_interactions() # Clears the target buffer and resets interactions
172
+
173
+ # Now you can start segmenting the next object in the image.
174
+
175
+ # --- Set a New Image ---
176
+ # Setting a new image also requires setting a new matching target buffer
177
+ session.set_image(NEW_IMAGE)
178
+ session.set_target_buffer(torch.zeros(NEW_IMAGE.shape[1:], dtype=torch.uint8))
179
+
180
+ # Enjoy!
181
+ ```
182
+
183
+ ## Citation
184
+ When using nnInteractive, please cite the following paper:
185
+
186
+ > Isensee, F.\*, Rokuss, M.\*, Krämer, L.\*, Dinkelacker, S., Ravindran, A., Stritzke, F., Hamm, B., Wald, T., Langenberg, M., Ulrich, C., Deissler, J., Floca, R., & Maier-Hein, K. (2025). nnInteractive: Redefining 3D Promptable Segmentation. https://arxiv.org/abs/2503.08373 \
187
+ > *: equal contribution
188
+
189
+ Link: [![arXiv](https://img.shields.io/badge/arXiv-2503.08373-b31b1b.svg)](https://arxiv.org/abs/2503.08373)
190
+
191
+ # License
192
+ Note that while this repository is available under Apache-2.0 license (see [LICENSE](./LICENSE)), the [model checkpoint](https://huggingface.co/nnInteractive/nnInteractive) is `Creative Commons Attribution Non Commercial Share Alike 4.0`!
193
+
194
+ ## Acknowledgments
195
+
196
+ <p align="left">
197
+ <img src="imgs/Logos/HI_Logo.png" width="150"> &nbsp;&nbsp;&nbsp;&nbsp;
198
+ <img src="imgs/Logos/DKFZ_Logo.png" width="500">
199
+ </p>
200
+
201
+ This repository is developed and maintained by the Applied Computer Vision Lab (ACVL)
202
+ of [Helmholtz Imaging](https://www.helmholtz-imaging.de/) and the
203
+ [Division of Medical Image Computing](https://www.dkfz.de/en/medical-image-computing) at DKFZ.
204
+
205
+ Github repo: https://github.com/MIC-DKFZ/nnInteractive