"""Gradio demo: physics-informed nnU-Net glioblastoma MRI segmentation.

At import time this script downloads the trained nnU-Net model folder (unless
already cached), builds an ``nnUNetPredictor`` from it, and defines a Gradio
interface that segments an uploaded NIfTI volume.
"""

import os
import tempfile

import gdown
import gradio as gr
import nibabel as nib
import numpy as np
import torch

# nnU-Net imports
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

# -------------------------------
# STEP 1 — Download model if needed
# -------------------------------
MODEL_DIR = "models/nnUNet_trained"
os.makedirs(MODEL_DIR, exist_ok=True)

FOLDER_URL = "https://drive.google.com/drive/folders/163zOL8NmdYqhRCAGNOWG7Ak85_awF4Xv"

# Only fold_0 is probed; a partially downloaded model with fold_0 present
# will be treated as cached.
if not os.path.exists(os.path.join(MODEL_DIR, "fold_0")):
    print("⏬ Downloading nnU-Net model from Google Drive ...")
    gdown.download_folder(FOLDER_URL, output=MODEL_DIR, quiet=False, use_cookies=False)
    print("✅ Model downloaded.")
else:
    print("✅ Model already cached, skipping download.")

# -------------------------------
# STEP 2 — Initialize predictor
# -------------------------------
print("🔧 Initializing nnU-Net predictor ...")
# The nnUNetPredictor constructor only takes inference settings (device,
# tiling, mirroring, ...). The model folder, folds and checkpoint name are
# arguments of initialize_from_trained_model_folder().
predictor = nnUNetPredictor(
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
predictor.initialize_from_trained_model_folder(
    MODEL_DIR,
    use_folds=(0, 1, 2, 3, 4),  # ensemble of 5 folds
    checkpoint_name="checkpoint_final.pth",
)
print("✅ Predictor ready.")

# -------------------------------
# STEP 3 — Define inference function
# -------------------------------
# nnU-Net identifies input channels by a "_XXXX" suffix before the file
# ending. The single channel is saved as "<case>_0000.nii.gz", and the
# prediction is written as "<case>.nii.gz".
# NOTE(review): assumes the model's dataset.json uses file_ending ".nii.gz".
_CASE_ID = "case"


def _upload_path(nii_file):
    """Return the filesystem path of a Gradio ``File`` upload.

    Depending on the Gradio version, the component passes either a path
    string or a tempfile-like object exposing ``.name``.

    Raises:
        gr.Error: if nothing was uploaded.
    """
    if nii_file is None:
        raise gr.Error("Please upload a .nii or .nii.gz file.")
    if isinstance(nii_file, (str, os.PathLike)):
        return os.fspath(nii_file)
    return nii_file.name


def _to_unit_range(arr):
    """Min-max scale *arr* to [0, 1]; a constant array maps to all zeros."""
    arr = np.nan_to_num(np.asarray(arr, dtype=np.float64))
    lo, hi = arr.min(), arr.max()
    if hi <= lo:
        # Avoid dividing by zero on blank/constant slices.
        return np.zeros_like(arr)
    return (arr - lo) / (hi - lo)


def _make_overlay(img, seg, alpha=0.5):
    """Return an RGB uint8 preview of the middle slice along the third axis.

    The image slice is shown in grayscale; voxels with a label > 0 in *seg*
    are blended towards red with weight *alpha*.
    """
    mid_slice = img.shape[2] // 2
    base = _to_unit_range(img[:, :, mid_slice])
    mask = seg[:, :, mid_slice] > 0

    rgb = np.stack([base, base, base], axis=-1)
    rgb[mask] = (1.0 - alpha) * rgb[mask] + alpha * np.array([1.0, 0.0, 0.0])
    return (rgb * 255.0).round().astype(np.uint8)


def run_inference(nii_file):
    """Segment an uploaded NIfTI volume with the nnU-Net fold ensemble.

    Args:
        nii_file: value of the ``gr.File`` input. This is a path string or a
            tempfile-like object with ``.name``, depending on the Gradio version.

    Returns:
        ``(overlay, result_path)``:

        * ``overlay`` is an RGB uint8 preview of the middle slice with the
          predicted foreground in red.
        * ``result_path`` is the path of the full 3D segmentation (.nii.gz).

    Raises:
        gr.Error: if no file was uploaded.
        RuntimeError: if nnU-Net produced no output NIfTI file.
    """
    upload_path = _upload_path(nii_file)

    # The output folder must outlive this call. Gradio reads the returned
    # file path *after* the function returns, so it cannot live inside a
    # TemporaryDirectory context (that would delete it on return).
    output_dir = tempfile.mkdtemp(prefix="nnunet_pred_")

    image = nib.load(upload_path)
    img = image.get_fdata()

    with tempfile.TemporaryDirectory() as input_dir:
        # Re-save through nibabel so that both .nii and .nii.gz uploads
        # become a genuinely gzip-compressed file with the name nnU-Net expects.
        input_path = os.path.join(input_dir, f"{_CASE_ID}_0000.nii.gz")
        nib.save(image, input_path)

        # nnU-Net expects a list of cases, each a list of per-channel files.
        predictor.predict_from_files(
            [[input_path]],
            output_dir,
            save_probabilities=False,
        )

    # nnU-Net also writes JSON metadata into the output folder; keep only
    # the NIfTI predictions.
    output_files = sorted(f for f in os.listdir(output_dir) if f.endswith(".nii.gz"))
    if not output_files:
        raise RuntimeError("Prediction failed: No output NIfTI file found.")
    result_path = os.path.join(output_dir, output_files[0])

    seg = nib.load(result_path).get_fdata()
    overlay = _make_overlay(img, seg)

    return (overlay, result_path)


# -------------------------------
# STEP 4 — Gradio Interface
# -------------------------------
title = "🧠 Physics-Informed nnU-Net (Fisher PDE) for Glioblastoma MRI Segmentation"
description = """
Upload a **.nii** or **.nii.gz** MRI volume to run inference using the Physics-Informed nnU-Net model.
The model integrates a Fisher diffusion-reaction equation for enhanced tumor boundary accuracy.
"""

iface = gr.Interface(
    fn=run_inference,
    inputs=gr.File(label="Upload MRI NIfTI (.nii / .nii.gz)"),
    outputs=[
        gr.Image(label="Segmentation Overlay (mid-slice)"),
        gr.File(label="Download Full 3D Segmentation (.nii.gz)"),
    ],
    title=title,
    description=description,
    allow_flagging="never",
)

if __name__ == "__main__":
    iface.launch()