"""Streamlit app: run nnU-Net (v2) inference on an uploaded NIfTI volume and
render the image together with the predicted label map in 3D via PyVista."""

import os
import tempfile
import time

import nibabel as nib
import numpy as np
import pyvista as pv
import streamlit as st
import torch
from matplotlib import cm, colormaps  # `cm` kept for backward compatibility
from matplotlib.colors import ListedColormap
# nnU-Net and visualization imports
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
from stpyvista import stpyvista

# Label id -> structure name, matching the default Dataset114_ToothFairy2 model.
# Only the largest key is used (to size the colormap); the names document the ids.
LABEL_DICT = {
    1: "Lower Jawbone", 2: "Upper Jawbone", 3: "Left Inferior Alveolar Canal",
    4: "Right Inferior Alveolar Canal", 5: "Left Maxillary Sinus",
    6: "Right Maxillary Sinus", 7: "Pharynx", 8: "Bridge", 9: "Crown",
    10: "Implant",
    11: "Upper Right Central Incisor", 12: "Upper Right Lateral Incisor",
    13: "Upper Right Canine", 14: "Upper Right First Premolar",
    15: "Upper Right Second Premolar", 16: "Upper Right First Molar",
    17: "Upper Right Second Molar", 18: "Upper Right Third Molar",
    21: "Upper Left Central Incisor", 22: "Upper Left Lateral Incisor",
    23: "Upper Left Canine", 24: "Upper Left First Premolar",
    25: "Upper Left Second Premolar", 26: "Upper Left First Molar",
    27: "Upper Left Second Molar", 28: "Upper Left Third Molar",
    31: "Lower Left Central Incisor", 32: "Lower Left Lateral Incisor",
    33: "Lower Left Canine", 34: "Lower Left First Premolar",
    35: "Lower Left Second Premolar", 36: "Lower Left First Molar",
    37: "Lower Left Second Molar", 38: "Lower Left Third Molar",
    41: "Lower Right Central Incisor", 42: "Lower Right Lateral Incisor",
    43: "Lower Right Canine", 44: "Lower Right First Premolar",
    45: "Lower Right Second Premolar", 46: "Lower Right First Molar",
    47: "Lower Right Second Molar", 48: "Lower Right Third Molar",
}

NIFTI_SUFFIX = ".nii.gz"


# --- Caching the nnU-Net Predictor ---
@st.cache_resource(show_spinner="Initializing nnU-Net predictor... (This may take a moment)")
def _create_predictor(model_folder):
    """Build and initialize an nnUNetPredictor for ``model_folder`` (cached per path).

    Errors propagate on purpose: Streamlit does not cache a call that raises,
    so a failed load is retried on the next rerun. Returning None from here
    would cache the failure for this path until the server restarts.
    """
    predictor = nnUNetPredictor(
        tile_step_size=0.5,
        use_gaussian=True,
        use_mirroring=True,
        perform_everything_on_device=True,
        device=torch.device('cuda', 0) if torch.cuda.is_available() else torch.device('cpu'),
        verbose=False,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    predictor.initialize_from_trained_model_folder(
        model_folder,
        use_folds=(0,),  # only fold 0 is loaded
        checkpoint_name='checkpoint_final.pth',
    )
    return predictor


def load_predictor(model_folder):
    """Return the cached, initialized nnUNetPredictor for ``model_folder``.

    Args:
        model_folder: nnU-Net trained-model folder (trainer__plans__configuration).

    Returns:
        The predictor, or None if initialization failed (an error is shown in the app).
    """
    try:
        predictor = _create_predictor(model_folder)
    except Exception as e:  # UI boundary: show any load failure instead of a traceback
        st.error(f"Failed to initialize predictor from {model_folder}. Error: {e}")
        return None
    st.success("nnU-Net predictor initialized successfully!")
    return predictor


# --- Visualization ---
def generate_visualization(base_image_path, mask_path):
    """Build a PyVista plotter showing the image volume with the label mask overlaid.

    Args:
        base_image_path: path to the input image NIfTI.
        mask_path: path to the predicted label map NIfTI (assumed to share the
            image's voxel grid, as nnU-Net predictions do).

    Returns:
        A ``pv.Plotter`` holding the volumes "CT Scan" and "Segmentation Mask".
    """
    img_data = nib.load(base_image_path).get_fdata(dtype=np.float32)
    # Normalize intensities to [0, 1]. A constant volume has ptp == 0 and would
    # otherwise become all-NaN.
    value_range = np.ptp(img_data)
    img_data = img_data - np.min(img_data)
    if value_range > 0:
        img_data = img_data / value_range

    # get_fdata() returns floats; round before the integer cast so labels are never truncated.
    mask_data = np.rint(nib.load(mask_path).get_fdata()).astype(np.uint8)

    # One colour per label id 0..max: label 0 is transparent, ids 1-20 take
    # tab20b, 21-40 take tab20c, and the remainder take gist_rainbow.
    num_labels = max(LABEL_DICT) + 1
    colors = np.vstack([
        [[0, 0, 0, 0]],
        colormaps['tab20b'](np.linspace(0, 1, 20)),
        colormaps['tab20c'](np.linspace(0, 1, 20)),
        colormaps['gist_rainbow'](np.linspace(0, 1, num_labels)),
    ])[:num_labels]
    colormap = ListedColormap(colors)

    # Per-label opacity: background hidden, every structure equally visible.
    # A 2-point list like [0, 0.5] would be ramped across the label range,
    # leaving low ids (e.g. the jawbones) almost invisible.
    label_opacity = np.full(num_labels, 0.5)
    label_opacity[0] = 0.0

    plotter = pv.Plotter(window_size=[800, 800])
    plotter.add_volume(pv.wrap(img_data), cmap="bone", opacity="sigmoid", name="CT Scan")
    plotter.add_volume(
        pv.wrap(mask_data),
        cmap=colormap,
        # Fix the scalar range to all label ids so colours do not shift when
        # the highest labels are absent from this particular mask.
        clim=[0, num_labels - 1],
        n_colors=num_labels,
        opacity=label_opacity,
        mapper='gpu',
        name="Segmentation Mask",
    )
    plotter.camera_position = 'xy'
    return plotter


def _find_output_mask(output_dir):
    """Return the path of the predicted NIfTI in ``output_dir``, or None if absent.

    predict_from_files() also writes JSON bookkeeping files (dataset.json,
    plans.json, ...) into the output folder, so the listing is filtered by suffix.
    """
    masks = sorted(name for name in os.listdir(output_dir) if name.endswith(NIFTI_SUFFIX))
    return os.path.join(output_dir, masks[0]) if masks else None


# --- Main Streamlit App ---
def main():
    """Streamlit entry point: configure the model, upload an image, predict, and visualize."""
    st.set_page_config(layout="wide", page_title="nnU-Net Inference App")
    st.title("🦷 nnU-Net Inference and 3D Visualization")
    st.markdown("Upload a medical image, run nnU-Net for segmentation, and visualize the results in 3D.")

    # --- Sidebar for Inputs ---
    st.sidebar.header("1. Configure Model")
    # IMPORTANT: Update this path to your default nnU-Net results folder
    default_model_path = "/path/to/your/nnUNet_results/Dataset114_ToothFairy2/nnUNetTrainer__nnUNetPlans__3d_fullres"
    model_folder = st.sidebar.text_input(
        "Enter path to trained model folder:",
        value=default_model_path,
    )
    if not os.path.isdir(model_folder):
        st.sidebar.error("Model folder not found. Please provide a valid path.")
        st.stop()

    predictor = load_predictor(model_folder)
    if predictor is None:
        st.stop()

    st.sidebar.header("2. Upload Image")
    uploaded_file = st.sidebar.file_uploader(
        "Choose a NIfTI file (.nii.gz)",
        type=['nii.gz'],
    )

    # --- Main Panel for Execution and Visualization ---
    if uploaded_file is None:
        st.info("Please upload a file to begin.")
        return
    if not st.sidebar.button("✨ Run Prediction and Visualize"):
        return

    # Temporary directory: automatic cleanup of inputs/outputs after this run.
    with tempfile.TemporaryDirectory() as temp_dir:
        input_dir = os.path.join(temp_dir, 'input')
        output_dir = os.path.join(temp_dir, 'output')
        os.makedirs(input_dir, exist_ok=True)
        os.makedirs(output_dir, exist_ok=True)

        # nnU-Net expects the _0000 channel suffix on input files. Strip only the
        # trailing extension (str.replace would also remove inner matches) and
        # drop any directory parts of the client-supplied name.
        file_name = os.path.basename(uploaded_file.name)
        if file_name.endswith(NIFTI_SUFFIX):
            base_name = file_name[:-len(NIFTI_SUFFIX)]
        else:
            base_name = file_name
        base_name = base_name or "case"
        input_file_path = os.path.join(input_dir, f"{base_name}_0000.nii.gz")
        with open(input_file_path, "wb") as f:
            f.write(uploaded_file.getbuffer())
        st.info(f"File '{uploaded_file.name}' saved to temporary location.")

        # --- Run Prediction ---
        with st.spinner("🧠 Running nnU-Net inference... This can take a while."):
            start_time = time.perf_counter()
            try:
                predictor.predict_from_files(
                    input_dir,
                    output_dir,
                    save_probabilities=False,
                    overwrite=True,
                    num_processes_preprocessing=2,
                    num_processes_segmentation_export=2,
                )
            except Exception as e:  # UI boundary: report instead of dumping a traceback
                st.error(f"Prediction failed. Error: {e}")
                return
            elapsed = time.perf_counter() - start_time
        st.success(f"Inference complete! 🎉 (Time taken: {elapsed:.2f} seconds)")

        output_mask_path = _find_output_mask(output_dir)
        if output_mask_path is None:
            st.error("Prediction failed. No output file was generated.")
            st.stop()

        # --- Generate Visualization ---
        with st.spinner("🎨 Generating 3D visualization..."):
            plotter = generate_visualization(input_file_path, output_mask_path)
        stpyvista(plotter, key="pv_plot")

        # Read the mask now: the temporary directory is deleted when this block exits.
        with open(output_mask_path, "rb") as f:
            mask_bytes = f.read()

    # --- Provide Download Link for the Mask ---
    # NOTE(review): clicking this button reruns the script, and the Run button is
    # then no longer pressed, so the results above disappear. Consider persisting
    # them in st.session_state if that matters.
    st.download_button(
        label="⬇️ Download Segmentation Mask",
        data=mask_bytes,
        file_name=f"predicted_{file_name}",
        mime="application/gzip",
    )


if __name__ == '__main__':
    main()