# BRATS demo — app.py (Hugging Face Space "Hedi-Bk", checkpoint commit a3523ff)
# NOTE(review): the original first lines were Hugging Face file-viewer page
# residue (not Python) and would raise a SyntaxError; converted to a comment.
import io
import numpy as np
import nibabel as nib
import torch
import gradio as gr
import matplotlib.pyplot as plt
from PIL import Image
from monai.transforms import (
Compose, NormalizeIntensityd,
ResizeWithPadOrCropd, ToTensord
)
from monai.networks.nets import SwinUNETR
# --- One-time setup at import: model, preprocessing pipeline, result caches ---

# SwinUNETR configured for BraTS-style input: 4 MRI modalities in, 3 output
# channels (one sigmoid map per nested tumor region), 128^3 patch size.
model = SwinUNETR(img_size=(128, 128, 128),
                  in_channels=4,
                  out_channels=3,
                  feature_size=48,
                  use_checkpoint=True)  # gradient checkpointing to save memory
# NOTE(review): the checkpoint filename genuinely contains a space
# ("last_state_dict .pth") — confirm it matches the file shipped with the app.
model.load_state_dict(torch.load("last_state_dict .pth", map_location=torch.device('cpu')))
model.eval()

# Inference-time preprocessing: per-channel z-score normalization over nonzero
# voxels, then pad/crop every volume to the model's fixed 128^3 input size.
val_transform = Compose([
    NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
    ResizeWithPadOrCropd(keys="image", spatial_size=(128, 128, 128)),
    ToTensord(keys="image")
])

# Module-level caches populated by predict_and_store and read by the slice
# slider callback; start as empty lists so `len(...) == 0` means "no run yet".
cached_images = []
cached_masks = []
# Helper to generate an overlay image for one slice
# Helper to generate an overlay image for one slice
def get_overlay_figure(image, mask, slice_index):
    """Render one slice of `image` with the segmentation `mask` blended on top.

    Parameters
    ----------
    image : np.ndarray
        3-D grayscale volume (H, W, D) — presumably the FLAIR channel cached
        by predict_and_store; TODO confirm with caller.
    mask : np.ndarray
        3-D label volume using BraTS-style labels (1=TC, 2=WT, 4=ET).
    slice_index : int
        Index along the last axis of both volumes.

    Returns
    -------
    PIL.Image.Image
        RGB rendering of the blended overlay.
    """
    img_slice = image[:, :, slice_index]
    # Normalize to [0, 1]. Guard against a constant slice (e.g. all-background
    # after pad-to-128^3), which previously divided by zero and produced NaNs.
    lo, hi = img_slice.min(), img_slice.max()
    if hi > lo:
        img_slice = (img_slice - lo) / (hi - lo)
    else:
        img_slice = np.zeros_like(img_slice, dtype=float)
    img_rgb = np.stack([img_slice] * 3, axis=-1)

    mask_slice = mask[:, :, slice_index]
    mask_rgb = np.zeros_like(img_rgb)
    # Paint broad-to-specific so later labels win where regions overlap.
    mask_rgb[mask_slice == 2] = [0, 0, 1]  # Blue: WT
    mask_rgb[mask_slice == 1] = [0, 1, 0]  # Green: TC
    mask_rgb[mask_slice == 4] = [1, 0, 0]  # Red: ET

    alpha = 0.4
    overlay = np.clip((1 - alpha) * img_rgb + alpha * mask_rgb, 0, 1)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(overlay)
    ax.axis("off")
    ax.set_title(f"Segmentation Overlay - Slice {slice_index}")
    buf = io.BytesIO()
    # Save this specific figure (not pyplot's "current" one) to avoid
    # cross-talk if Gradio serves concurrent requests.
    fig.savefig(buf, format='png')
    buf.seek(0)
    plt.close(fig)
    return Image.open(buf)
# Prediction function that stores the processed image and mask for slider use
# Prediction function that stores the processed image and mask for slider use
def predict_and_store(flair_file, t1_file, t1ce_file, t2_file):
    """Segment the four uploaded NIfTI modalities and cache the result.

    Writes the preprocessed FLAIR volume and predicted label map into the
    module-level `cached_images` / `cached_masks` so the slice slider can
    render overlays without re-running inference.

    Args are gradio File objects; `.name` holds the temp-file path.
    Returns a status string for the UI textbox (error text on failure —
    deliberately never raises, so the interface stays responsive).
    """
    try:
        global cached_images, cached_masks
        # Load each modality as a float volume (H, W, D).
        flair = nib.load(flair_file.name).get_fdata()
        t1 = nib.load(t1_file.name).get_fdata()
        t1ce = nib.load(t1ce_file.name).get_fdata()
        t2 = nib.load(t2_file.name).get_fdata()
        # Channel order must match training: FLAIR, T1, T1ce, T2 —
        # NOTE(review): presumed from the model's in_channels=4; confirm
        # against the training pipeline.
        image = np.stack([flair, t1, t1ce, t2], axis=0)
        data = {"image": image}
        data = val_transform(data)  # normalize + pad/crop to 128^3
        image_tensor = data["image"].unsqueeze(0)  # add batch dim
        with torch.no_grad():
            output = model(image_tensor)
            # Per-channel sigmoid: the 3 outputs are overlapping binary
            # region maps (presumably BraTS order 0=TC, 1=WT, 2=ET — verify).
            prob = torch.sigmoid(output)
            seg = prob[0].detach().cpu().numpy()
            seg = (seg > 0.5).astype(np.int8)
        # Collapse the 3 binary maps into one label volume. Assignment order
        # matters: label 2 (WT) first, then 1 (TC), then 4 (ET), so the more
        # specific nested regions overwrite the broader ones.
        seg_out = np.zeros((seg.shape[1], seg.shape[2], seg.shape[3]))
        seg_out[seg[1] == 1] = 2
        seg_out[seg[0] == 1] = 1
        seg_out[seg[2] == 1] = 4
        # Cache only channel 0 (FLAIR) of the preprocessed input for display.
        cached_images = image_tensor.cpu().numpy()[0, 0]
        cached_masks = seg_out
        return f"Segmentation done. Use the slider to browse slices."
    except Exception as e:
        # Surface any failure as a status message instead of crashing the UI.
        return f"❌Error: {str(e)}"
# Function to get the overlay for a specific slice
# Function to get the overlay for a specific slice
def get_slice_overlay(slice_index):
    """Return the overlay image for `slice_index`, or None before any run.

    Reads the module-level caches filled by predict_and_store; while they
    are still empty (no segmentation yet), yields None so the image widget
    stays blank.
    """
    global cached_images, cached_masks
    if len(cached_images) != 0:
        return get_overlay_figure(cached_images, cached_masks, slice_index)
    return None
# --- Gradio UI: upload widgets, run button, and a slice-browser slider ---
with gr.Blocks() as iface:
    gr.Markdown("### 🧠 SwinUNETR Brain Tumor Segmentation Viewer")
    gr.Markdown("#### 📂 Upload MRI Modalities")
    # One file input per modality, laid out side by side.
    with gr.Row():
        flair = gr.File(file_types=[".nii", ".nii.gz"], label="FLAIR (.nii)")
        t1 = gr.File(file_types=[".nii", ".nii.gz"], label="T1 (.nii)")
        t1ce = gr.File(file_types=[".nii", ".nii.gz"], label="T1ce (.nii)")
        t2 = gr.File(file_types=[".nii", ".nii.gz"], label="T2 (.nii)")
    # Example data section (bundled sample scans in ./examples).
    gr.Markdown("#### 🧪 Or use example data")
    examples = gr.Examples(
        examples=[
            ["examples/flair.nii", "examples/t1.nii", "examples/t1ce.nii", "examples/t2.nii"]
        ],
        inputs=[flair, t1, t1ce, t2],
        label="Example MRI files"
    )
    output_msg = gr.Textbox(label="📣 Status")
    run_button = gr.Button("▶️ Run Segmentation")
    # Clicking runs inference and populates the module-level caches; only a
    # status string is returned to the UI.
    run_button.click(fn=predict_and_store,
                     inputs=[flair, t1, t1ce, t2],
                     outputs=output_msg)
    gr.Markdown("### 🖼️ View Slices with Overlay")
    # Slider range 0-127 matches the fixed 128^3 model input volume.
    slice_slider = gr.Slider(minimum=0, maximum=127, value=64, step=1, label="Slice Index")
    slice_image = gr.Image(type="pil", label="Segmentation Overlay")
    # Moving the slider re-renders the overlay from the cached volumes only
    # (no model re-run).
    slice_slider.change(fn=get_slice_overlay,
                        inputs=slice_slider,
                        outputs=slice_image)
# share=True exposes a public tunnel URL in addition to the local server.
iface.launch(share=True)