| import io |
| import numpy as np |
| import nibabel as nib |
| import torch |
| import gradio as gr |
| import matplotlib.pyplot as plt |
| from PIL import Image |
| from monai.transforms import ( |
| Compose, NormalizeIntensityd, |
| ResizeWithPadOrCropd, ToTensord |
| ) |
| from monai.networks.nets import SwinUNETR |
|
|
|
|
# Segmentation network: SwinUNETR with 4 input channels (the four stacked MRI
# modalities) and 3 output channels on 128^3 volumes. Weights are restored
# onto the CPU and the network is switched to evaluation mode for inference.
# NOTE(review): `use_checkpoint` (gradient checkpointing) only saves memory
# during training, and newer MONAI releases deprecate `img_size` - confirm
# both against the installed MONAI version.
model = SwinUNETR(
    img_size=(128, 128, 128),
    in_channels=4,
    out_channels=3,
    feature_size=48,
    use_checkpoint=True,
)

# NOTE(review): the checkpoint filename contains a space before ".pth";
# confirm it matches the file on disk. torch.load unpickles arbitrary
# objects, so only ever point it at a trusted checkpoint.
model.load_state_dict(
    torch.load("last_state_dict .pth", map_location=torch.device("cpu"))
)
model.eval()
# Inference-time preprocessing for the dict {"image": <channel-first volume>}:
# per-channel intensity normalisation over non-zero voxels, pad/crop of the
# spatial dimensions to 128^3 (the size the model was built for), then
# conversion to a torch tensor.
val_transform = Compose(
    [
        NormalizeIntensityd(keys="image", nonzero=True, channel_wise=True),
        ResizeWithPadOrCropd(keys="image", spatial_size=(128, 128, 128)),
        ToTensord(keys="image"),
    ]
)
|
|
| |
# Result of the most recent segmentation run, read by the slice viewer.
# Both start as empty lists ("nothing segmented yet"); predict_and_store
# replaces them with NumPy volumes. They are module-level, so every caller
# of the app sees (and overwrites) the same values.
cached_images, cached_masks = [], []
|
|
| |
def get_overlay_figure(image, mask, slice_index, alpha=0.4):
    """Render one slice of ``image`` with the label ``mask`` blended on top.

    Args:
        image: 3-D intensity volume; the slice shown is ``image[:, :, k]``.
        mask: 3-D label volume indexed the same way as ``image``. Label 2 is
            drawn blue, 1 green and 4 red; any other value is left uncoloured.
        slice_index: index along the last axis. Converted with ``int()``
            because UI sliders may deliver it as a float, which NumPy rejects
            as an index.
        alpha: opacity of the label colours, from 0 (image only) to 1
            (labels only). Defaults to the original 0.4.

    Returns:
        A ``PIL.Image.Image`` holding the rendered PNG.
    """
    k = int(slice_index)

    img_slice = np.asarray(image[:, :, k], dtype=np.float64)
    lo = img_slice.min()
    span = img_slice.max() - lo
    # A constant slice (for example one made entirely of zero padding) would
    # divide by zero and fill the image with NaNs; render it black instead.
    if span > 0:
        img_slice = (img_slice - lo) / span
    else:
        img_slice = np.zeros_like(img_slice)
    img_rgb = np.stack([img_slice] * 3, axis=-1)

    mask_slice = mask[:, :, k]
    mask_rgb = np.zeros_like(img_rgb)
    for label, colour in ((2, (0, 0, 1)), (1, (0, 1, 0)), (4, (1, 0, 0))):
        mask_rgb[mask_slice == label] = colour

    overlay = np.clip((1 - alpha) * img_rgb + alpha * mask_rgb, 0, 1)

    # NOTE(review): imshow draws the first array axis as rows, so the slice
    # is shown in raw voxel orientation (possibly rotated relative to the
    # usual radiological view) - transpose/flip here if that is unwanted.
    fig, ax = plt.subplots(figsize=(6, 6))
    try:
        ax.imshow(overlay)
        ax.axis("off")
        ax.set_title(f"Segmentation Overlay - Slice {k}")
        buf = io.BytesIO()
        # Save *this* figure rather than pyplot's global "current" figure,
        # which a concurrent request could have replaced in the meantime.
        fig.savefig(buf, format="png")
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
|
|
| |
def _upload_path(upload):
    """Return a filesystem path for a Gradio ``File`` value.

    Depending on the Gradio version/configuration the value is either a
    plain path (str / PathLike) or a tempfile-like object exposing ``.name``.
    """
    import os

    if isinstance(upload, (str, bytes, os.PathLike)):
        return upload
    return upload.name


def predict_and_store(flair_file, t1_file, t1ce_file, t2_file):
    """Segment the four uploaded MRI volumes and cache the result for viewing.

    Each argument is a Gradio ``File`` value (a path string or an object with
    a ``.name`` path). The volumes are stacked channel-first in the order
    FLAIR, T1, T1ce, T2, preprocessed with ``val_transform`` and run through
    ``model``. The three sigmoid output channels are thresholded at 0.5 and
    merged into one label map: channel 1 -> 2, channel 0 -> 1, channel 2 -> 4,
    with later assignments overriding earlier ones where they overlap.
    (Presumably BraTS labels with TC/WT/ET output channels - confirm against
    the training setup.)

    Side effects:
        On success sets the module globals ``cached_images`` (the first,
        FLAIR, channel after preprocessing) and ``cached_masks`` (the label
        map). On failure both are left untouched.

    Returns:
        A status message for the UI; errors are reported as text prefixed with
        "❌Error:" rather than raised.
    """
    # NOTE(review): these globals are shared by every session, so with
    # share=True concurrent users overwrite each other's results; per-session
    # gr.State would fix that but requires rewiring the UI.
    global cached_images, cached_masks
    import logging

    # Channel order is assumed to match the order used to train the model.
    uploads = {"FLAIR": flair_file, "T1": t1_file, "T1ce": t1ce_file, "T2": t2_file}
    missing = [name for name, upload in uploads.items() if upload is None]
    if missing:
        return f"❌Error: missing input file(s): {', '.join(missing)}"

    try:
        volumes = [nib.load(_upload_path(u)).get_fdata() for u in uploads.values()]
        image = np.stack(volumes, axis=0)
        data = val_transform({"image": image})
        image_tensor = data["image"].unsqueeze(0)  # add batch dimension

        with torch.no_grad():
            prob = torch.sigmoid(model(image_tensor))
        seg = (prob[0].cpu().numpy() > 0.5).astype(np.int8)

        seg_out = np.zeros(seg.shape[1:])
        seg_out[seg[1] == 1] = 2
        seg_out[seg[0] == 1] = 1
        seg_out[seg[2] == 1] = 4
    except Exception as e:
        # UI boundary: keep the app alive and show the message, but keep the
        # traceback in the server log instead of discarding it.
        logging.getLogger(__name__).exception("Segmentation failed")
        return f"❌Error: {str(e)}"

    cached_images = image_tensor.cpu().numpy()[0, 0]
    cached_masks = seg_out
    return "Segmentation done. Use the slider to browse slices."
|
|
| |
def get_slice_overlay(slice_index):
    """Render ``slice_index`` of the most recently cached segmentation.

    Returns ``None`` while nothing has been segmented yet; otherwise returns
    whatever ``get_overlay_figure`` produces for the cached volumes.
    """
    # The globals are only read here, so no ``global`` statement is needed.
    # ``len`` instead of truthiness: after a run the cache is a NumPy array,
    # and ``bool()`` of a multi-element array raises.
    if len(cached_images) > 0:
        return get_overlay_figure(cached_images, cached_masks, slice_index)
    return None
|
|
# Gradio UI: four NIfTI uploads (or bundled examples), a run button with a
# status box, and a slider-driven slice viewer for the cached result.
with gr.Blocks() as iface:
    gr.Markdown("### 🧠 SwinUNETR Brain Tumor Segmentation Viewer")

    gr.Markdown("#### 📂 Upload MRI Modalities")
    with gr.Row():
        flair = gr.File(file_types=[".nii", ".nii.gz"], label="FLAIR (.nii)")
        t1 = gr.File(file_types=[".nii", ".nii.gz"], label="T1 (.nii)")
        t1ce = gr.File(file_types=[".nii", ".nii.gz"], label="T1ce (.nii)")
        t2 = gr.File(file_types=[".nii", ".nii.gz"], label="T2 (.nii)")

    gr.Markdown("#### 🧪 Or use example data")
    examples = gr.Examples(
        examples=[
            ["examples/flair.nii", "examples/t1.nii", "examples/t1ce.nii", "examples/t2.nii"]
        ],
        inputs=[flair, t1, t1ce, t2],
        label="Example MRI files",
    )

    output_msg = gr.Textbox(label="📣 Status")
    run_button = gr.Button("▶️ Run Segmentation")

    gr.Markdown("### 🖼️ View Slices with Overlay")
    # 0..127 matches the 128-voxel depth produced by val_transform.
    slice_slider = gr.Slider(minimum=0, maximum=127, value=64, step=1, label="Slice Index")
    slice_image = gr.Image(type="pil", label="Segmentation Overlay")

    # Event wiring comes after all components exist so the run event can also
    # refresh the viewer; previously the overlay stayed blank after a run
    # until the user moved the slider.
    run_button.click(
        fn=predict_and_store,
        inputs=[flair, t1, t1ce, t2],
        outputs=output_msg,
    ).then(
        fn=get_slice_overlay,
        inputs=slice_slider,
        outputs=slice_image,
    )
    slice_slider.change(
        fn=get_slice_overlay,
        inputs=slice_slider,
        outputs=slice_image,
    )


# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch(share=True)
|
|