# Source snapshot: revision 065e508 (8,539 bytes).
import streamlit as st
import torch
import os
import tempfile
import time
# nnU-Net and visualization imports
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
import pyvista as pv
import nibabel as nib
import numpy as np
from matplotlib import cm
from matplotlib.colors import ListedColormap
from stpyvista import stpyvista
# --- Caching the nnU-Net Predictor ---
# The model is loaded once per distinct argument set and kept in memory across
# Streamlit reruns, which is crucial for performance.
@st.cache_resource(
    # A spinner is shown only while the function actually runs; an st.write()
    # inside the body would be replayed on every cache hit instead.
    show_spinner="Initializing nnU-Net predictor... (This may take a moment)",
    # A failed load returns None. Rejecting None here keeps that failure out of
    # the cache, so the next rerun retries instead of staying broken until the
    # server restarts.
    validate=lambda cached: cached is not None,
)
def load_predictor(model_folder, use_folds=(0,), checkpoint_name='checkpoint_final.pth'):
    """Load and initialize an nnUNetPredictor from a trained model folder.

    Args:
        model_folder: Path to the nnU-Net trained-model folder.
        use_folds: Folds whose checkpoints are loaded (default: fold 0 only).
            Must be hashable (e.g. a tuple) because it is part of the cache key.
        checkpoint_name: Checkpoint file name to load for each fold.

    Returns:
        The initialized predictor, or None if initialization failed. In that
        case an error message is shown in the app.
    """
    device = torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu')
    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=True,
        # Keeping the whole pipeline on the device only makes sense on a GPU.
        perform_everything_on_device=device.type == 'cuda',
        device=device,
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    try:
        predictor.initialize_from_trained_model_folder(
            model_folder,
            use_folds=use_folds,
            checkpoint_name=checkpoint_name,
        )
    except Exception as e:  # UI boundary: report any loading failure in the app
        st.error(f"Failed to initialize predictor from {model_folder}. Error: {e}")
        return None
    st.success("nnU-Net predictor initialized successfully!")
    return predictor
# --- Visualization Function (from your script) ---
def generate_visualization(base_image_path, mask_path, mask_opacity=0.5):
    """Build a PyVista scene showing a scan volume with its segmentation overlaid.

    Args:
        base_image_path: Path to the input scan (NIfTI).
        mask_path: Path to the integer label mask (NIfTI) aligned with the scan.
        mask_opacity: Opacity given to every non-zero label. Label 0
            (background) is always fully transparent.

    Returns:
        A configured ``pv.Plotter``; the caller is responsible for rendering it.
    """
    # `cm.get_cmap` was deprecated in Matplotlib 3.7 and removed in 3.9;
    # the `matplotlib.colormaps` registry is its replacement.
    import matplotlib

    # Load the base CT scan and rescale intensities to 0–1.
    img = nib.load(base_image_path)
    img_data = img.get_fdata()
    intensity_range = np.ptp(img_data)
    if intensity_range > 0:
        img_data = (img_data - np.min(img_data)) / intensity_range
    else:
        # Constant image: dividing by a zero range would make every voxel NaN.
        img_data = np.zeros_like(img_data)

    # Load segmentation mask as integer labels.
    mask = nib.load(mask_path)
    mask_data = mask.get_fdata().astype(np.uint8)

    # Label dictionary (from your script)
    label_dict = {
        1: "Lower Jawbone", 2: "Upper Jawbone", 3: "Left Inferior Alveolar Canal",
        4: "Right Inferior Alveolar Canal", 5: "Left Maxillary Sinus", 6: "Right Maxillary Sinus",
        7: "Pharynx", 8: "Bridge", 9: "Crown", 10: "Implant", 11: "Upper Right Central Incisor",
        12: "Upper Right Lateral Incisor", 13: "Upper Right Canine", 14: "Upper Right First Premolar",
        15: "Upper Right Second Premolar", 16: "Upper Right First Molar", 17: "Upper Right Second Molar",
        18: "Upper Right Third Molar", 21: "Upper Left Central Incisor",
        22: "Upper Left Lateral Incisor", 23: "Upper Left Canine", 24: "Upper Left First Premolar",
        25: "Upper Left Second Premolar", 26: "Upper Left First Molar", 27: "Upper Left Second Molar",
        28: "Upper Left Third Molar", 31: "Lower Left Central Incisor",
        32: "Lower Left Lateral Incisor", 33: "Lower Left Canine", 34: "Lower Left First Premolar",
        35: "Lower Left Second Premolar", 36: "Lower Left First Molar", 37: "Lower Left Second Molar",
        38: "Lower Left Third Molar", 41: "Lower Right Central Incisor",
        42: "Lower Right Lateral Incisor", 43: "Lower Right Canine", 44: "Lower Right First Premolar",
        45: "Lower Right Second Premolar", 46: "Lower Right First Molar", 47: "Lower Right Second Molar",
        48: "Lower Right Third Molar"
    }

    # One colour slot per label value 0..num_labels-1. The largest label actually
    # present in the mask is included too, so unexpected values still get a slot.
    num_labels = max(max(label_dict), int(mask_data.max())) + 1
    colors = np.vstack([
        [[0, 0, 0, 0]],  # background: transparent black
        matplotlib.colormaps['tab20b'](np.linspace(0, 1, 20)),
        matplotlib.colormaps['tab20c'](np.linspace(0, 1, 20)),
        matplotlib.colormaps['gist_rainbow'](np.linspace(0, 1, num_labels)),
    ])[:num_labels]
    colormap = ListedColormap(colors)

    # Per-label opacity: 0 for background and mask_opacity for every real label.
    # A two-point ramp like [0, 0.5] would be interpolated across the whole label
    # range, leaving low label values (e.g. the jawbones) almost invisible.
    label_opacity = np.full(num_labels, float(mask_opacity))
    label_opacity[0] = 0.0

    # Wrap the arrays as volumes and apply each file's voxel size, so that
    # anisotropic scans are not stretched.
    vol_img = pv.wrap(img_data)
    vol_img.spacing = img.header.get_zooms()[:3]
    vol_mask = pv.wrap(mask_data)
    vol_mask.spacing = mask.header.get_zooms()[:3]

    plotter = pv.Plotter(window_size=[800, 800])
    plotter.add_volume(vol_img, cmap="bone", opacity="sigmoid", name="CT Scan")
    plotter.add_volume(
        vol_mask,
        cmap=colormap,
        # Fixed range and one colour per label: label k always maps to colour k,
        # regardless of which labels happen to be present in this mask.
        clim=[0, num_labels - 1],
        n_colors=num_labels,
        opacity=label_opacity,
        mapper='gpu',  # Use GPU for better performance
        name="Segmentation Mask",
    )
    plotter.camera_position = 'xy'
    return plotter
# --- Main Streamlit App ---
# Session-state key under which the latest prediction survives Streamlit reruns.
_RESULT_STATE_KEY = "nnunet_prediction"


def _run_prediction(predictor, uploaded_file, base_name):
    """Run nnU-Net on an uploaded volume and return the predicted mask as bytes.

    The upload is written to a throw-away input folder as
    ``<base_name>_0000.nii.gz`` and the prediction goes to a sibling output
    folder. Both are deleted before this function returns. Progress and errors
    are reported in the app.

    Returns:
        Raw bytes of the predicted ``.nii.gz`` mask, or None on failure.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        input_dir = os.path.join(temp_dir, 'input')
        output_dir = os.path.join(temp_dir, 'output')
        os.makedirs(input_dir)
        os.makedirs(output_dir)

        # The filename needs the _0000 suffix for nnU-Net's default file prediction
        input_file_path = os.path.join(input_dir, f"{base_name}_0000.nii.gz")
        with open(input_file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        st.info(f"File '{uploaded_file.name}' saved to temporary location.")

        with st.spinner("🧠 Running nnU-Net inference... This can take a while."):
            start_time = time.perf_counter()
            try:
                predictor.predict_from_files(
                    input_dir,
                    output_dir,
                    save_probabilities=False,
                    overwrite=True,
                    num_processes_preprocessing=2,
                    num_processes_segmentation_export=2,
                )
            except Exception as e:  # UI boundary: show the failure instead of a raw traceback
                st.error(f"Prediction failed. Error: {e}")
                return None
            elapsed = time.perf_counter() - start_time
        st.success(f"Inference complete! 🎉 (Time taken: {elapsed:.2f} seconds)")

        # predict_from_files also writes JSON metadata (dataset.json, plans.json, ...)
        # into the output folder, so pick the segmentation by extension rather
        # than by directory-listing position.
        mask_names = sorted(n for n in os.listdir(output_dir) if n.endswith(".nii.gz"))
        if not mask_names:
            st.error("Prediction failed. No output file was generated.")
            return None
        with open(os.path.join(output_dir, mask_names[0]), "rb") as f:
            return f.read()


def _show_results(uploaded_file, mask_bytes, download_name):
    """Render the 3D overlay and a download button for a stored prediction.

    The scan and mask are written back to a temporary folder because
    ``generate_visualization`` reads them from file paths.
    """
    with tempfile.TemporaryDirectory() as temp_dir:
        image_path = os.path.join(temp_dir, "image.nii.gz")
        mask_path = os.path.join(temp_dir, "mask.nii.gz")
        with open(image_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        with open(mask_path, "wb") as f:
            f.write(mask_bytes)
        with st.spinner("🎨 Generating 3D visualization..."):
            plotter = generate_visualization(image_path, mask_path)
            stpyvista(plotter, key="pv_plot")

    st.download_button(
        label="⬇️ Download Segmentation Mask",
        data=mask_bytes,
        file_name=download_name,
        mime="application/gzip",
    )


def main():
    """Render the nnU-Net inference page: model selection, upload, prediction, results.

    The latest prediction is stored in ``st.session_state``. Streamlit reruns the
    whole script on every interaction (including a click on the download
    button), while ``st.button`` is True only in the run where it was clicked.
    Without the stored result, the plot and download link would vanish on the
    next interaction.
    """
    st.set_page_config(layout="wide", page_title="nnU-Net Inference App")
    st.title("🦷 nnU-Net Inference and 3D Visualization")
    st.markdown("Upload a medical image, run nnU-Net for segmentation, and visualize the results in 3D.")

    # --- Sidebar for Inputs ---
    st.sidebar.header("1. Configure Model")
    # IMPORTANT: Update this path to your default nnU-Net results folder
    default_model_path = "/path/to/your/nnUNet_results/Dataset114_ToothFairy2/nnUNetTrainer__nnUNetPlans__3d_fullres"
    model_folder = st.sidebar.text_input(
        "Enter path to trained model folder:",
        value=default_model_path
    )
    if not os.path.isdir(model_folder):
        st.sidebar.error("Model folder not found. Please provide a valid path.")
        st.stop()

    # Load the model (cached by load_predictor)
    predictor = load_predictor(model_folder)
    if predictor is None:
        st.stop()

    st.sidebar.header("2. Upload Image")
    uploaded_file = st.sidebar.file_uploader(
        "Choose a NIfTI file (.nii.gz)",
        type=['nii.gz']
    )

    # --- Main Panel for Execution and Visualization ---
    if uploaded_file is None:
        st.info("Please upload a file to begin.")
        return

    # The file name comes from the browser: drop any directory component before
    # using it in a path, and strip only the trailing extension.
    safe_name = os.path.basename(uploaded_file.name)
    suffix = ".nii.gz"
    base_name = safe_name[:-len(suffix)] if safe_name.endswith(suffix) else safe_name
    # Ties a stored result to the upload and model that produced it.
    result_id = (safe_name, uploaded_file.size, model_folder)

    if st.sidebar.button("✨ Run Prediction and Visualize"):
        mask_bytes = _run_prediction(predictor, uploaded_file, base_name)
        if mask_bytes is None:
            st.session_state.pop(_RESULT_STATE_KEY, None)
            st.stop()
        st.session_state[_RESULT_STATE_KEY] = {"id": result_id, "mask": mask_bytes}

    result = st.session_state.get(_RESULT_STATE_KEY)
    if result is not None and result["id"] == result_id:
        _show_results(uploaded_file, result["mask"], f"predicted_{safe_name}")
# Script entry point (``streamlit run`` executes this module as __main__).
if __name__ == '__main__':
    main()