File size: 3,592 Bytes
9866a07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import gzip
import os
import shutil
import tempfile

import gdown
import gradio as gr
import nibabel as nib
import numpy as np
import torch

# nnU-Net imports
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

# -------------------------------
# STEP 1 — Download model if needed
# -------------------------------
MODEL_DIR = "models/nnUNet_trained"
os.makedirs(MODEL_DIR, exist_ok=True)

FOLDER_URL = "https://drive.google.com/drive/folders/163zOL8NmdYqhRCAGNOWG7Ak85_awF4Xv"

# One final checkpoint per fold. This mirrors the folds (0-4) and the
# checkpoint name ("checkpoint_final.pth") the predictor is initialised with
# in STEP 2. Checking every file, rather than only the presence of a
# "fold_0" directory, means a partially finished download is detected and
# retried instead of being treated as a valid cache.
_REQUIRED_CHECKPOINTS = [
    os.path.join(MODEL_DIR, f"fold_{fold}", "checkpoint_final.pth")
    for fold in range(5)
]


def _missing_checkpoints():
    """Return the entries of ``_REQUIRED_CHECKPOINTS`` that are not files on disk."""
    return [path for path in _REQUIRED_CHECKPOINTS if not os.path.isfile(path)]


if _missing_checkpoints():
    print("⏬ Downloading nnU-Net model from Google Drive ...")
    gdown.download_folder(FOLDER_URL, output=MODEL_DIR, quiet=False, use_cookies=False)
    # Verify the result instead of assuming success, so a failed or partial
    # download stops here with a clear message rather than later inside the
    # predictor.
    _missing = _missing_checkpoints()
    if _missing:
        raise RuntimeError(
            "Model download incomplete; missing checkpoint(s): " + ", ".join(_missing)
        )
    print("βœ… Model downloaded.")
else:
    print("βœ… Model already cached, skipping download.")

# -------------------------------
# STEP 2 — Initialize predictor
# -------------------------------
print("πŸ”§ Initializing nnU-Net predictor ...")
# nnUNetPredictor's constructor only takes inference settings (device,
# tiling, mirroring, ...). The trained model to load, the fold ensemble and
# the checkpoint name belong to initialize_from_trained_model_folder(), whose
# model folder and `use_folds` arguments are required.
predictor = nnUNetPredictor(
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
predictor.initialize_from_trained_model_folder(
    MODEL_DIR,
    use_folds=(0, 1, 2, 3, 4),  # ensemble of 5 folds
    checkpoint_name="checkpoint_final.pth",
)
print("βœ… Predictor ready.")

# -------------------------------
# STEP 3 — Define inference function
# -------------------------------
def _uploaded_path(nii_file):
    """Return the local filesystem path of a ``gr.File`` upload.

    Gradio passes the callback either a path (``str`` / ``os.PathLike``) or,
    in older releases, a tempfile-like wrapper whose ``.name`` is the path.
    Neither offers a ``.save`` method.
    """
    if nii_file is None:
        raise gr.Error("Please upload a .nii or .nii.gz MRI volume.")
    if isinstance(nii_file, (str, os.PathLike)):
        return os.fspath(nii_file)
    return nii_file.name


def _copy_as_nii_gz(src_path, dst_path):
    """Copy a NIfTI file to *dst_path*, gzip-compressing it unless it already ends in ``.gz``.

    Copying a plain ``.nii`` under a ``.nii.gz`` name without compressing it
    produces a file that gzip-based readers cannot open.
    """
    if src_path.lower().endswith(".gz"):
        shutil.copyfile(src_path, dst_path)
    else:
        with open(src_path, "rb") as fin, gzip.open(dst_path, "wb") as fout:
            shutil.copyfileobj(fin, fout)


def _make_overlay(img, seg, alpha=0.5):
    """Render the middle slice (along axis 2) of *img* with *seg* blended in green.

    Args:
        img: image volume. Arrays with more than 3 dimensions are reduced by
            taking index 0 of each trailing axis.
        seg: label volume sharing the first three dimensions of *img*;
            voxels > 0 are drawn as foreground.
        alpha: opacity of the green mask colour, in [0, 1].

    Returns:
        ``(H, W, 3)`` ``uint8`` RGB array.
    """
    img = np.asarray(img, dtype=np.float64)
    seg = np.asarray(seg)
    # Drop trailing axes so that the slice taken below is 2-D.
    while img.ndim > 3:
        img = img[..., 0]
    while seg.ndim > 3:
        seg = seg[..., 0]

    mid = img.shape[2] // 2
    base = np.nan_to_num(img[:, :, mid])
    lo, hi = base.min(), base.max()
    # Min-max scale into [0, 1]. Dividing by the volume max breaks on blank
    # volumes (division by zero) and on negative intensities.
    base = (base - lo) / (hi - lo) if hi > lo else np.zeros_like(base)

    # Grayscale base in all three channels, mask alpha-blended in green.
    rgb = np.repeat(base[..., np.newaxis], 3, axis=-1)
    mask = seg[:, :, mid] > 0
    rgb[mask] = (1.0 - alpha) * rgb[mask] + alpha * np.array([0.0, 1.0, 0.0])
    return np.round(rgb * 255).astype(np.uint8)


def run_inference(nii_file):
    """Segment an uploaded NIfTI volume with the nnU-Net ensemble.

    Args:
        nii_file: value produced by the ``gr.File`` input. This is a path, or
            a tempfile-like object with ``.name``, depending on the Gradio
            version.

    Returns:
        tuple: ``(overlay, result_path)``.
            - ``overlay`` is an RGB ``uint8`` preview of the middle slice.
            - ``result_path`` is the path of the predicted ``.nii.gz``
              segmentation.

    Raises:
        RuntimeError: if nnU-Net wrote no NIfTI output.
    """
    src_path = _uploaded_path(nii_file)

    # The returned file is read by Gradio only after this function returns,
    # by which time any TemporaryDirectory context has already been deleted.
    # The prediction therefore goes to a persistent temp dir; only the scratch
    # input copy is cleaned up.
    output_dir = tempfile.mkdtemp(prefix="nnunet_pred_")

    with tempfile.TemporaryDirectory() as tmpdir:
        # nnU-Net names channel files "<case>_0000<ending>" and derives the
        # case id by stripping that suffix. This assumes the model's
        # dataset.json uses file_ending ".nii.gz" (TODO confirm).
        input_path = os.path.join(tmpdir, "case_0000.nii.gz")
        _copy_as_nii_gz(src_path, input_path)

        # One inner list of channel files per case (a single-channel case here).
        predictor.predict_from_files(
            [[input_path]],
            output_dir,
            save_probabilities=False,
        )

        output_files = sorted(f for f in os.listdir(output_dir) if f.endswith(".nii.gz"))
        if not output_files:
            raise RuntimeError("Prediction failed: No output NIfTI file found.")
        result_path = os.path.join(output_dir, output_files[0])

        # Load while the input copy still exists.
        img = nib.load(input_path).get_fdata()
        seg = nib.load(result_path).get_fdata()

    return (_make_overlay(img, seg), result_path)

# -------------------------------
# STEP 4 — Gradio Interface
# -------------------------------
title = "🧠 Physics-Informed nnU-Net (Fisher PDE) for Glioblastoma MRI Segmentation"
description = """

Upload a **.nii** or **.nii.gz** MRI volume to run inference using the Physics-Informed nnU-Net model.

The model integrates a Fisher diffusion-reaction equation for enhanced tumor boundary accuracy.

"""

# One upload in. Two results out, in the same order as the
# (overlay, result_path) tuple returned by run_inference.
upload_component = gr.File(label="Upload MRI NIfTI (.nii / .nii.gz)")
output_components = [
    gr.Image(label="Segmentation Overlay (mid-slice)"),
    gr.File(label="Download Full 3D Segmentation (.nii.gz)"),
]

iface = gr.Interface(
    run_inference,
    inputs=upload_component,
    outputs=output_components,
    title=title,
    description=description,
    allow_flagging="never",  # flagging turned off
)


if __name__ == "__main__":
    # Launch the app only when executed as a script, not on import.
    iface.launch()