import os

import cv2
import gradio as gr
import numpy as np
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image

############################################
# ========== ORIGINAL TRAINING UNET =========
############################################
# NOTE: the attribute names below (conv_op, conv, pool, up, down_convolution_*,
# bottle_neck, up_convolution_*, out) become the state_dict keys, so they must
# stay identical to the names used when my_checkpoint.pth was saved.


class DoubleConv(nn.Module):
    """Two 3x3 Conv2d + ReLU layers; padding=1 keeps height/width unchanged."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv_op = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.conv_op(x)


class DownSample(nn.Module):
    """Encoder stage: DoubleConv followed by 2x2 max-pooling.

    ``forward`` returns ``(features, pooled)``. The pre-pool features are kept
    as the skip connection for the matching decoder stage.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        down = self.conv(x)
        p = self.pool(down)
        return down, p


class UpSample(nn.Module):
    """Decoder stage: 2x transposed-conv upsample, concat skip, DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Halve the channels so that after concatenation with the skip tensor
        # (which has in_channels // 2 channels) DoubleConv sees in_channels.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        x = torch.cat([x1, x2], 1)  # concatenate along the channel dimension
        return self.conv(x)


class UNet(nn.Module):
    """Four-level U-Net; returns one output channel per class with no activation.

    Input height and width must be divisible by 16 (four 2x poolings) so the
    upsampled tensors line up with the skip connections in ``torch.cat``.
    The 256x256 ``transform`` below satisfies this.
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()
        self.down_convolution_1 = DownSample(in_channels, 64)
        self.down_convolution_2 = DownSample(64, 128)
        self.down_convolution_3 = DownSample(128, 256)
        self.down_convolution_4 = DownSample(256, 512)

        self.bottle_neck = DoubleConv(512, 1024)

        self.up_convolution_1 = UpSample(1024, 512)
        self.up_convolution_2 = UpSample(512, 256)
        self.up_convolution_3 = UpSample(256, 128)
        self.up_convolution_4 = UpSample(128, 64)

        self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        down_1, p1 = self.down_convolution_1(x)
        down_2, p2 = self.down_convolution_2(p1)
        down_3, p3 = self.down_convolution_3(p2)
        down_4, p4 = self.down_convolution_4(p3)

        b = self.bottle_neck(p4)

        up_1 = self.up_convolution_1(b, down_4)
        up_2 = self.up_convolution_2(up_1, down_3)
        up_3 = self.up_convolution_3(up_2, down_2)
        up_4 = self.up_convolution_4(up_3, down_1)

        return self.out(up_4)


############################################
# ========== LOAD MODEL ====================
############################################

device = torch.device("cpu")

model = UNet(in_channels=3, num_classes=1)
# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
model.load_state_dict(torch.load("my_checkpoint.pth", map_location=device))
model.eval()

############################################
# ========== TRANSFORM =====================
############################################

# Side length of the square input fed to the network.
IMG_SIZE = 256

transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),  # PIL -> float tensor (C, H, W) scaled to [0, 1]
])

############################################
# ========== DICE FUNCTION =================
############################################


def dice_coefficient(pred, target, epsilon=1e-7):
    """Return the Dice score between a thresholded prediction and a target mask.

    Args:
        pred: tensor of probabilities; binarized at 0.5 here.
        target: binary (0/1) tensor broadcast-compatible with ``pred``.
        epsilon: smoothing term; makes an empty-vs-empty comparison score 1.0.

    Returns:
        The Dice score as a Python float.
    """
    pred = (pred > 0.5).float()
    intersection = (pred * target).sum()
    union = pred.sum() + target.sum()
    return ((2. * intersection + epsilon) / (union + epsilon)).item()


############################################
# ========== TIFF SAFE LOADER ==============
############################################


def _file_path(file):
    """Return a filesystem path for a Gradio file value.

    Newer Gradio passes a plain path string; older releases pass a tempfile
    wrapper exposing the path as ``.name``. Both are accepted.
    """
    if isinstance(file, (str, os.PathLike)):
        return file
    return file.name


def load_image(file):
    """Load an uploaded image file.

    Returns:
        ``(img_pil, img_np)``. ``img_pil`` is an RGB PIL image at the file's
        original resolution. ``img_np`` is the raw pixel array: 16-bit data is
        scaled down to uint8 by dividing by 256; anything else keeps the file's
        own dtype and channel layout (which may be 2-D or 4-channel).
    """
    with Image.open(_file_path(file)) as img:
        img_np = np.array(img)

        # Handle 16-bit TIFF (either byte order: "I;16" gives <u2, "I;16B" gives >u2).
        if img_np.dtype.kind == "u" and img_np.dtype.itemsize == 2:
            img_np = (img_np / 256).astype(np.uint8)
            img_pil = Image.fromarray(img_np).convert("RGB")
        else:
            # Let PIL do the conversion so palette ("P") images keep their colours.
            img_pil = img.convert("RGB")

    return img_pil, img_np


############################################
# ========== PREDICTION ====================
############################################


def predict(image_file, mask_file=None):
    """Segment an uploaded image and optionally score it against a mask.

    Args:
        image_file: Gradio file value for the input image (path or wrapper).
        mask_file: optional Gradio file value for the ground-truth mask.

    Returns:
        ``(overlay, message)``. ``overlay`` is an RGB uint8 array at the
        original resolution with predicted pixels painted red (or ``None`` if
        no image was given). ``message`` is a status or Dice-score string.
    """
    if image_file is None:
        return None, "Please upload an image."

    image_pil, _ = load_image(image_file)
    # Use the RGB-converted image as the overlay base: the raw array from
    # load_image can be 2-D (grayscale / 16-bit) or 4-channel, neither of
    # which accepts an RGB colour assignment.
    overlay = np.array(image_pil)

    input_tensor = transform(image_pil).unsqueeze(0).to(device)

    with torch.no_grad():
        output = torch.sigmoid(model(input_tensor))

    pred_mask = output.squeeze().cpu().numpy()  # (IMG_SIZE, IMG_SIZE) probabilities
    pred_binary = (pred_mask > 0.5).astype(np.uint8)

    # Resize mask back to original size; nearest-neighbour keeps it a clean 0/1 label map.
    height, width = overlay.shape[:2]
    pred_resized = cv2.resize(pred_binary, (width, height), interpolation=cv2.INTER_NEAREST)

    overlay[pred_resized == 1] = [255, 0, 0]

    if mask_file is not None:
        mask_pil, _ = load_image(mask_file)
        # Nearest resize + threshold keeps the ground truth binary (a bilinear
        # resize would introduce fractional values along the mask edges).
        mask_gray = mask_pil.convert("L").resize((IMG_SIZE, IMG_SIZE), Image.NEAREST)
        gt = torch.from_numpy((np.asarray(mask_gray) > 127).astype(np.float32))

        dice = dice_coefficient(torch.from_numpy(pred_mask), gt)
        return overlay, f"Dice Score: {round(dice, 4)}"

    return overlay, "Prediction complete."


############################################
# ========== GRADIO UI =====================
############################################

description = """
# 🧠 Brain Tumor Segmentation (UNet)

Dataset used for training:
https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation

Upload a `.tif` MRI image.
Optionally upload the ground-truth mask to compute Dice score.
"""

demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.File(file_types=[".tif", ".tiff", ".png", ".jpg"], label="Upload MRI Image"),
        gr.File(file_types=[".tif", ".tiff"], label="Optional Ground Truth Mask"),
    ],
    outputs=[
        gr.Image(label="Predicted Overlay"),
        gr.Textbox(label="Result"),
    ],
    title="UNet Brain Tumor Segmentation",
    description=description,
)

if __name__ == "__main__":
    demo.launch()