# app.py: Hugging Face Space file uploaded by medimaging
# (commit 9866a07, message "Upload app.py").
import gzip
import os
import shutil
import tempfile

import gdown
import gradio as gr
import nibabel as nib
import numpy as np
import torch

# nnU-Net imports
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
# -------------------------------
# STEP 1 — Download model if needed
# -------------------------------
MODEL_DIR = "models/nnUNet_trained"
FOLDER_URL = "https://drive.google.com/drive/folders/163zOL8NmdYqhRCAGNOWG7Ak85_awF4Xv"
# Folds ensembled at inference time and the checkpoint loaded from each one.
# Both the cache check and the predictor initialisation below use these, so
# they cannot disagree about what a complete model folder looks like.
MODEL_FOLDS = (0, 1, 2, 3, 4)
CHECKPOINT_NAME = "checkpoint_final.pth"


def _missing_model_files(model_dir, folds, checkpoint_name):
    """Return the model files expected under *model_dir* that do not exist.

    Expected: ``dataset.json`` and ``plans.json`` in *model_dir* plus
    ``fold_<k>/<checkpoint_name>`` for every fold in *folds* (the files
    nnU-Net's ``initialize_from_trained_model_folder`` loads).
    """
    required = [
        os.path.join(model_dir, "dataset.json"),
        os.path.join(model_dir, "plans.json"),
    ]
    required += [
        os.path.join(model_dir, f"fold_{fold}", checkpoint_name) for fold in folds
    ]
    return [path for path in required if not os.path.isfile(path)]


os.makedirs(MODEL_DIR, exist_ok=True)
if _missing_model_files(MODEL_DIR, MODEL_FOLDS, CHECKPOINT_NAME):
    print("⏬ Downloading nnU-Net model from Google Drive ...")
    gdown.download_folder(FOLDER_URL, output=MODEL_DIR, quiet=False, use_cookies=False)
    # Fail fast with a clear message instead of an obscure error from
    # nnU-Net when the download was partial or returned without raising.
    missing = _missing_model_files(MODEL_DIR, MODEL_FOLDS, CHECKPOINT_NAME)
    if missing:
        raise RuntimeError(
            "Model download incomplete; missing file(s): " + ", ".join(missing)
        )
    print("✅ Model downloaded.")
else:
    print("✅ Model already cached, skipping download.")

# -------------------------------
# STEP 2 — Initialize predictor
# -------------------------------
print("🔧 Initializing nnU-Net predictor ...")
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The constructor only takes runtime options (device, tiling, mirroring, ...).
# The model folder, folds and checkpoint name are arguments of
# initialize_from_trained_model_folder(), which builds the network and loads
# the weights.
predictor = nnUNetPredictor(device=DEVICE)
predictor.initialize_from_trained_model_folder(
    MODEL_DIR,
    use_folds=MODEL_FOLDS,  # ensemble of 5 folds
    checkpoint_name=CHECKPOINT_NAME,
)
print("✅ Predictor ready.")
# -------------------------------
# STEP 3 — Define inference function
# -------------------------------
# RGB colours (0-1 floats) blended over labels 1, 2, 3, 4 in the preview;
# higher label values reuse the palette cyclically.
_LABEL_COLORS = np.array(
    [
        [1.0, 0.0, 0.0],
        [0.0, 1.0, 0.0],
        [0.0, 0.4, 1.0],
        [1.0, 1.0, 0.0],
    ]
)


def _upload_path(nii_file):
    """Return the filesystem path of the value Gradio passed for a ``gr.File`` input.

    Raises:
        gr.Error: if nothing was uploaded.
    """
    if nii_file is None:
        raise gr.Error("Please upload a .nii or .nii.gz file first.")
    if isinstance(nii_file, (str, os.PathLike)):
        return os.fspath(nii_file)
    # Some Gradio versions pass a tempfile-like wrapper instead of a path
    # string; its ``.name`` attribute holds the path.
    return nii_file.name


def _stage_as_nii_gz(src_path, dst_path):
    """Copy *src_path* to *dst_path* so that *dst_path* holds gzip data.

    Gzip input (a ``.nii.gz``) is copied verbatim; anything else (e.g. an
    uncompressed ``.nii``) is gzip-compressed on the way, so a ``.nii.gz``
    destination name always matches its content. The gzip magic bytes are
    checked instead of the file name because the upload's name is not
    under our control.
    """
    with open(src_path, "rb") as src:
        is_gzip = src.read(2) == b"\x1f\x8b"
        src.seek(0)
        opener = open if is_gzip else gzip.open
        with opener(dst_path, "wb") as dst:
            shutil.copyfileobj(src, dst)


def _as_3d(volume):
    """Return *volume* as a 3-D array.

    Arrays with fewer than 3 axes get trailing singleton axes; arrays with
    more keep only the first entry of every axis beyond the third.
    """
    volume = np.asarray(volume)
    if volume.ndim < 3:
        return volume.reshape(volume.shape + (1,) * (3 - volume.ndim))
    return volume.reshape(volume.shape[:3] + (-1,))[..., 0]


def _mid_slice_overlay(img, seg, alpha=0.5):
    """Build a uint8 RGB preview of the middle slice along the third array axis.

    The image slice is min-max scaled to grayscale (constant or all-NaN slices
    become black). Each non-zero label of *seg* is alpha-blended with its
    colour from ``_LABEL_COLORS``. *img* and *seg* must share their first
    three dimensions.

    NOTE(review): the slice is shown in raw voxel-array orientation; no
    reorientation is applied, so it may look rotated relative to radiological
    convention.
    """
    img = _as_3d(img)
    seg = _as_3d(seg)
    mid = img.shape[2] // 2

    base = np.nan_to_num(img[:, :, mid].astype(np.float64))
    lo, hi = base.min(), base.max()
    gray = (base - lo) / (hi - lo) if hi > lo else np.zeros_like(base)
    rgb = np.repeat(gray[:, :, np.newaxis], 3, axis=2)

    labels = np.rint(seg[:, :, mid]).astype(np.int64)
    for label in np.unique(labels[labels > 0]):
        mask = labels == label
        color = _LABEL_COLORS[(label - 1) % len(_LABEL_COLORS)]
        rgb[mask] = (1.0 - alpha) * rgb[mask] + alpha * color

    return np.rint(rgb * 255.0).astype(np.uint8)


def run_inference(nii_file):
    """Segment an uploaded NIfTI volume with the module-level nnU-Net ``predictor``.

    Args:
        nii_file: The ``gr.File`` upload: a path string, or an object whose
            ``.name`` attribute is the path.

    Returns:
        ``(overlay, result_path)``.
        ``overlay`` is a uint8 RGB preview of the middle slice along the
        third array axis, with labels colour-blended over the grayscale image.
        ``result_path`` is the full 3-D segmentation written by nnU-Net. It is
        stored in a directory that survives this call, because Gradio reads
        the file after the function returns.

    Raises:
        gr.Error: if nothing was uploaded or the upload is not readable NIfTI.
        RuntimeError: if nnU-Net did not write the segmentation file.
    """
    src_path = _upload_path(nii_file)

    # mkdtemp() directories are never deleted automatically, so the returned
    # path stays valid for Gradio. NOTE(review): one directory is left behind
    # per request; add periodic cleanup if disk space matters.
    result_dir = tempfile.mkdtemp(prefix="nnunet_seg_")
    output_truncated = os.path.join(result_dir, "segmentation")

    with tempfile.TemporaryDirectory() as tmpdir:
        input_path = os.path.join(tmpdir, "input.nii.gz")
        _stage_as_nii_gz(src_path, input_path)
        try:
            img = nib.load(input_path).get_fdata()
        except Exception as err:  # user-input boundary: report, don't crash
            raise gr.Error(
                "Could not read the upload as a NIfTI volume (.nii / .nii.gz)."
            ) from err

        # Read with the image reader/writer the model's plans specify and
        # predict in-process. predict_from_files would start background worker
        # processes (its num_processes_* arguments); if those use the "spawn"
        # start method they re-import this script and re-run model setup.
        # NOTE(review): confirm before switching back to predict_from_files.
        reader = predictor.plans_manager.image_reader_writer_class()
        image, properties = reader.read_images([input_path])
        # NOTE(review): this feeds a single input channel; a multi-modal model
        # (e.g. T1/T1ce/T2/FLAIR) needs one file per channel.
        predictor.predict_single_npy_array(
            image, properties, None, output_truncated, False
        )

    result_path = output_truncated + predictor.dataset_json["file_ending"]
    if not os.path.isfile(result_path):
        raise RuntimeError("Prediction failed: No output NIfTI file found.")
    seg = nib.load(result_path).get_fdata()
    return _mid_slice_overlay(img, seg), result_path
# -------------------------------
# STEP 4 — Gradio Interface
# -------------------------------
# Page header shown above the UI (the emoji was previously mis-encoded).
title = "🧠 Physics-Informed nnU-Net (Fisher PDE) for Glioblastoma MRI Segmentation"
description = """
Upload a **.nii** or **.nii.gz** MRI volume to run inference using the Physics-Informed nnU-Net model.
The model integrates a Fisher diffusion-reaction equation for enhanced tumor boundary accuracy.
"""
# One file input; two outputs matching run_inference's (overlay, path) tuple.
iface = gr.Interface(
    fn=run_inference,
    inputs=gr.File(label="Upload MRI NIfTI (.nii / .nii.gz)"),
    outputs=[
        gr.Image(label="Segmentation Overlay (mid-slice)"),
        gr.File(label="Download Full 3D Segmentation (.nii.gz)"),
    ],
    title=title,
    description=description,
    # NOTE(review): newer Gradio releases rename this option to
    # `flagging_mode`; confirm against the pinned Gradio version.
    allow_flagging="never",
)

if __name__ == "__main__":
    # Start the Gradio web server.
    iface.launch()