File size: 8,539 Bytes
065e508
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
import streamlit as st
import torch
import os
import tempfile
import time

# nnU-Net and visualization imports
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
import pyvista as pv
import nibabel as nib
import numpy as np
from matplotlib import cm
from matplotlib.colors import ListedColormap
from stpyvista import stpyvista

# --- Caching the nnU-Net Predictor ---
# Building the predictor (reading the plans and checkpoint) is expensive, so the
# initialized object is kept in memory by st.cache_resource and reused on reruns.
@st.cache_resource
def _build_predictor(model_folder, use_folds, checkpoint_name):
    """Create and initialize an nnUNetPredictor, raising on any failure.

    Failures propagate as exceptions instead of returning None so that
    st.cache_resource does not store them: the next call retries instead of
    returning a cached failed result.

    Args:
        model_folder: Path to the trained nnU-Net model folder.
        use_folds: Tuple of fold indices to load (must be hashable for caching).
        checkpoint_name: Checkpoint file name inside each fold folder.

    Returns:
        The initialized nnUNetPredictor.
    """
    st.write("Initializing nnU-Net predictor... (This may take a moment)")

    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=True,
        perform_everything_on_device=True,
        device=torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu'),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    predictor.initialize_from_trained_model_folder(
        model_folder,
        use_folds=use_folds,
        checkpoint_name=checkpoint_name,
    )
    st.success("nnU-Net predictor initialized successfully!")
    return predictor


def load_predictor(model_folder, use_folds=(0,), checkpoint_name='checkpoint_final.pth'):
    """Return a cached, initialized nnUNetPredictor, or None if loading fails.

    Successful initializations are cached per (model_folder, use_folds,
    checkpoint_name). Failures are reported in the UI with st.error and are
    NOT cached, so a later rerun (e.g. after fixing the folder) tries again.

    Args:
        model_folder: Path to the trained nnU-Net model folder.
        use_folds: Iterable of fold indices to load; defaults to fold 0 only.
        checkpoint_name: Checkpoint file to load from each fold.

    Returns:
        The predictor, or None when construction or initialization raised.
    """
    try:
        # tuple() makes the folds hashable for st.cache_resource's key.
        return _build_predictor(model_folder, tuple(use_folds), checkpoint_name)
    except Exception as e:  # UI boundary: show the error instead of crashing the app
        st.error(f"Failed to initialize predictor from {model_folder}. Error: {e}")
        return None

# --- Visualization Function (from your script) ---
def _normalize_to_unit_range(data):
    """Linearly rescale *data* into [0, 1].

    A constant volume (zero peak-to-peak range) maps to all zeros instead of
    producing NaNs from a 0/0 division.
    """
    data_min = np.min(data)
    data_range = np.ptp(data)
    if data_range == 0:
        return np.zeros_like(data, dtype=np.float64)
    return (data - data_min) / data_range


def _build_label_colormap(num_labels):
    """Return a ListedColormap with exactly *num_labels* RGBA entries.

    Entry 0 (background) is fully transparent black; the remaining entries are
    taken, in order, from the tab20b, tab20c and gist_rainbow palettes.
    """
    # ``matplotlib.cm.get_cmap`` was removed in Matplotlib 3.9; the
    # ``matplotlib.colormaps`` registry is the supported lookup.
    import matplotlib

    registry = matplotlib.colormaps
    colors = np.vstack([
        [[0, 0, 0, 0]],
        registry['tab20b'](np.linspace(0, 1, 20)),
        registry['tab20c'](np.linspace(0, 1, 20)),
        registry['gist_rainbow'](np.linspace(0, 1, num_labels)),
    ])
    return ListedColormap(colors[:num_labels])


def generate_visualization(base_image_path, mask_path):
    """Build a PyVista plotter showing the scan with the label mask overlaid.

    Args:
        base_image_path: Path to the NIfTI image that was segmented.
        mask_path: Path to the NIfTI label mask produced for that image.

    Returns:
        A ``pv.Plotter`` holding two volumes, "CT Scan" (intensities rescaled
        to [0, 1]) and "Segmentation Mask". The caller renders it.
    """
    # Load base scan and rescale intensities to [0, 1] for the "bone" colormap.
    img = nib.load(base_image_path)
    img_data = _normalize_to_unit_range(img.get_fdata())

    # Load segmentation mask; read stored label values without a float64 copy.
    mask = nib.load(mask_path)
    mask_data = np.asanyarray(mask.dataobj).astype(np.uint8)

    # Label dictionary (from your script)
    label_dict = {
        1: "Lower Jawbone", 2: "Upper Jawbone", 3: "Left Inferior Alveolar Canal",
        4: "Right Inferior Alveolar Canal", 5: "Left Maxillary Sinus", 6: "Right Maxillary Sinus",
        7: "Pharynx", 8: "Bridge", 9: "Crown", 10: "Implant", 11: "Upper Right Central Incisor",
        12: "Upper Right Lateral Incisor", 13: "Upper Right Canine", 14: "Upper Right First Premolar",
        15: "Upper Right Second Premolar", 16: "Upper Right First Molar", 17: "Upper Right Second Molar",
        18: "Upper Right Third Molar", 21: "Upper Left Central Incisor",
        22: "Upper Left Lateral Incisor", 23: "Upper Left Canine", 24: "Upper Left First Premolar",
        25: "Upper Left Second Premolar", 26: "Upper Left First Molar", 27: "Upper Left Second Molar",
        28: "Upper Left Third Molar", 31: "Lower Left Central Incisor",
        32: "Lower Left Lateral Incisor", 33: "Lower Left Canine", 34: "Lower Left First Premolar",
        35: "Lower Left Second Premolar", 36: "Lower Left First Molar", 37: "Lower Left Second Molar",
        38: "Lower Left Third Molar", 41: "Lower Right Central Incisor",
        42: "Lower Right Lateral Incisor", 43: "Lower Right Canine", 44: "Lower Right First Premolar",
        45: "Lower Right Second Premolar", 46: "Lower Right First Molar", 47: "Lower Right Second Molar",
        48: "Lower Right Third Molar"
    }

    # One colormap entry per possible label value 0..max label.
    num_labels = max(label_dict) + 1
    colormap = _build_label_colormap(num_labels)

    # Wrap data in PyVista objects.
    vol_img = pv.wrap(img_data)
    vol_mask = pv.wrap(mask_data)

    # Use the physical voxel size so anisotropic scans are not drawn stretched.
    zooms = img.header.get_zooms()[:3]
    if len(zooms) == 3:
        spacing = tuple(float(z) for z in zooms)
        vol_img.spacing = spacing
        vol_mask.spacing = spacing

    plotter = pv.Plotter(window_size=[800, 800])
    plotter.add_volume(vol_img, cmap="bone", opacity="sigmoid", name="CT Scan")

    # One opacity value per label: background (0) fully transparent, every
    # structure at 0.5. Pinning clim to [0, num_labels - 1] with
    # n_colors == num_labels makes label k use colormap entry k regardless of
    # which labels are actually present in this mask.
    mask_opacity = [0.0] + [0.5] * (num_labels - 1)
    plotter.add_volume(
        vol_mask,
        cmap=colormap,
        clim=[0, num_labels - 1],
        n_colors=num_labels,
        opacity=mask_opacity,
        mapper='gpu',  # Use GPU for better performance
        name="Segmentation Mask",
    )
    plotter.camera_position = 'xy'

    return plotter


def _find_predicted_mask(output_dir, case_id):
    """Return the path of the segmentation written for *case_id*, or None.

    The output folder can contain files other than the mask (nnU-Net v2 also
    writes JSON bookkeeping files there), so an arbitrary directory entry may
    be the wrong file. Prefer ``<case_id>.nii.gz``; otherwise fall back to the
    alphabetically first ``.nii.gz`` file.
    """
    expected = os.path.join(output_dir, f"{case_id}.nii.gz")
    if os.path.isfile(expected):
        return expected
    masks = sorted(name for name in os.listdir(output_dir) if name.endswith(".nii.gz"))
    return os.path.join(output_dir, masks[0]) if masks else None


# --- Main Streamlit App ---
def main():
    """Render the app: pick a model, upload a NIfTI scan, segment it, show it in 3D.

    Stops the Streamlit script early (st.stop) when the model folder is
    missing, the predictor cannot be loaded, or inference produced no mask.
    All intermediate files live in a temporary directory that is removed at
    the end of the run.
    """
    st.set_page_config(layout="wide", page_title="nnU-Net Inference App")

    st.title("🦷 nnU-Net Inference and 3D Visualization")
    st.markdown("Upload a medical image, run nnU-Net for segmentation, and visualize the results in 3D.")

    # --- Sidebar for Inputs ---
    st.sidebar.header("1. Configure Model")
    # IMPORTANT: Update this path to your default nnU-Net results folder
    default_model_path = "/path/to/your/nnUNet_results/Dataset114_ToothFairy2/nnUNetTrainer__nnUNetPlans__3d_fullres"
    model_folder = st.sidebar.text_input(
        "Enter path to trained model folder:",
        value=default_model_path
    )

    if not os.path.isdir(model_folder):
        st.sidebar.error("Model folder not found. Please provide a valid path.")
        st.stop()

    # Load the model (cached across reruns on success)
    predictor = load_predictor(model_folder)
    if predictor is None:
        st.stop()

    st.sidebar.header("2. Upload Image")
    uploaded_file = st.sidebar.file_uploader(
        "Choose a NIfTI file (.nii.gz)",
        type=['nii.gz']
    )

    if uploaded_file is None:
        st.info("Please upload a file to begin.")
        return

    if not st.sidebar.button("✨ Run Prediction and Visualize"):
        return

    # --- Main Panel for Execution and Visualization ---
    # A temporary directory keeps uploads isolated and is cleaned up automatically.
    with tempfile.TemporaryDirectory() as temp_dir:
        input_dir = os.path.join(temp_dir, 'input')
        output_dir = os.path.join(temp_dir, 'output')
        os.makedirs(input_dir, exist_ok=True)
        os.makedirs(output_dir, exist_ok=True)

        # nnU-Net expects "<case_id>_0000.nii.gz" for the first input channel.
        # Strip only the trailing extension (and any directory part) to get
        # the case id; fall back to a fixed id if nothing is left.
        suffix = ".nii.gz"
        file_name = os.path.basename(uploaded_file.name)
        case_id = file_name[:-len(suffix)] if file_name.endswith(suffix) else file_name
        case_id = case_id or "case"
        input_file_path = os.path.join(input_dir, f"{case_id}_0000.nii.gz")

        with open(input_file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        st.info(f"File '{uploaded_file.name}' saved to temporary location.")

        # --- Run Prediction ---
        with st.spinner("🧠 Running nnU-Net inference... This can take a while."):
            start_time = time.perf_counter()
            try:
                # predict_from_files is the most efficient entry point for file-based workflows
                predictor.predict_from_files(
                    input_dir,
                    output_dir,
                    save_probabilities=False,
                    overwrite=True,
                    num_processes_preprocessing=2,
                    num_processes_segmentation_export=2
                )
            except Exception as e:  # UI boundary: report instead of a raw traceback
                st.error(f"Prediction failed. Error: {e}")
                st.stop()
            elapsed = time.perf_counter() - start_time
            st.success(f"Inference complete! 🎉 (Time taken: {elapsed:.2f} seconds)")

        # Locate the mask for this case (the folder may also contain non-mask files).
        output_mask_path = _find_predicted_mask(output_dir, case_id)
        if output_mask_path is None:
            st.error("Prediction failed. No output file was generated.")
            st.stop()

        # --- Generate Visualization ---
        with st.spinner("🎨 Generating 3D visualization..."):
            plotter = generate_visualization(input_file_path, output_mask_path)
            stpyvista(plotter, key="pv_plot")

        # --- Provide Download Link for the Mask ---
        # Read into memory now: the temporary directory is deleted on exit.
        with open(output_mask_path, "rb") as f:
            mask_bytes = f.read()
        st.download_button(
            label="⬇️ Download Segmentation Mask",
            data=mask_bytes,
            file_name=f"predicted_{uploaded_file.name}",
            mime="application/gzip"
        )

# Entry point when this file is executed as a script.
if __name__ == "__main__":
    main()