# Page-header residue from the hosting site ("Spaces: Sleeping"); not part of the program.
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as transforms | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
| import cv2 | |
| ############################################ | |
| # ========== ORIGINAL TRAINING UNET ========= | |
| ############################################ | |
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions, each followed by an in-place ReLU.

    Both convolutions use ``padding=1`` with a 3x3 kernel; the first maps
    ``in_channels`` to ``out_channels`` and the second keeps ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The layer order (and hence the Sequential indices 0..3) and the
        # attribute name ``conv_op`` determine the parameter names stored in
        # a state dict, so they are kept exactly as-is.
        layers = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        self.conv_op = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through both conv + ReLU stages and return the result."""
        return self.conv_op(x)
class DownSample(nn.Module):
    """Encoder stage: a ``DoubleConv`` followed by 2x2 max pooling (stride 2)."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names ``conv`` and ``pool`` are part of the state-dict
        # key layout and must not change.
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(features, pooled)``.

        ``features`` is the un-pooled convolution output (used later as the
        skip connection); ``pooled`` is the same tensor after max pooling.
        """
        features = self.conv(x)
        return features, self.pool(features)
class UpSample(nn.Module):
    """Decoder stage: transposed-conv upsampling, skip concatenation, DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # The transposed convolution (kernel 2, stride 2) outputs half of
        # ``in_channels``.
        self.up = nn.ConvTranspose2d(
            in_channels,
            in_channels // 2,
            kernel_size=2,
            stride=2,
        )
        # The DoubleConv consumes ``in_channels`` again, i.e. the upsampled
        # tensor concatenated with the skip tensor along the channel axis.
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        """Upsample ``x1``, concatenate with skip tensor ``x2`` on dim 1, then convolve."""
        upsampled = self.up(x1)
        merged = torch.cat((upsampled, x2), dim=1)
        return self.conv(merged)
class UNet(nn.Module):
    """UNet with four encoder stages, a bottleneck and four decoder stages.

    Channel widths go 64 -> 128 -> 256 -> 512 on the way down, 1024 in the
    bottleneck, and mirror back to 64 on the way up. A final 1x1 convolution
    maps the 64 channels to ``num_classes``. ``forward`` returns the raw
    output of that convolution; no activation is applied here.
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()
        # NOTE: attribute names below define the state-dict keys expected by
        # saved checkpoints; keep them (and their creation order) unchanged.

        # Encoder
        self.down_convolution_1 = DownSample(in_channels, 64)
        self.down_convolution_2 = DownSample(64, 128)
        self.down_convolution_3 = DownSample(128, 256)
        self.down_convolution_4 = DownSample(256, 512)

        # Bottleneck
        self.bottle_neck = DoubleConv(512, 1024)

        # Decoder
        self.up_convolution_1 = UpSample(1024, 512)
        self.up_convolution_2 = UpSample(512, 256)
        self.up_convolution_3 = UpSample(256, 128)
        self.up_convolution_4 = UpSample(128, 64)

        # 1x1 projection to the requested number of output channels
        self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        """Run the encoder, bottleneck and decoder; return the 1x1-conv output.

        Each decoder stage is paired with the skip tensor of the encoder stage
        at the same depth (deepest first). The input is pooled by 2 four
        times, so height and width are expected to be multiples of 16 for the
        skip concatenations to line up.
        """
        encoders = (
            self.down_convolution_1,
            self.down_convolution_2,
            self.down_convolution_3,
            self.down_convolution_4,
        )
        decoders = (
            self.up_convolution_1,
            self.up_convolution_2,
            self.up_convolution_3,
            self.up_convolution_4,
        )

        skips = []
        current = x
        for encoder in encoders:
            skip, current = encoder(current)
            skips.append(skip)

        current = self.bottle_neck(current)

        # Deepest skip connection is consumed first.
        for decoder, skip in zip(decoders, reversed(skips)):
            current = decoder(current, skip)

        return self.out(current)
| ############################################ | |
| # ========== LOAD MODEL ==================== | |
| ############################################ | |
# Inference runs on the CPU; ``map_location`` remaps the checkpoint's tensors
# onto that device while loading.
device = torch.device("cpu")
model = UNet(in_channels=3, num_classes=1)

# NOTE: torch.load unpickles the file, so only load checkpoints you trust.
_checkpoint_state = torch.load("my_checkpoint.pth", map_location=device)
model.load_state_dict(_checkpoint_state)

# Put the module in evaluation mode for inference.
model.eval()
| ############################################ | |
| # ========== TRANSFORM ===================== | |
| ############################################ | |
# Preprocessing shared by the input image and the optional mask:
# resize to 256x256, then convert the PIL image to a tensor.
transform = transforms.Compose(
    [
        transforms.Resize(size=(256, 256)),
        transforms.ToTensor(),
    ]
)
| ############################################ | |
| # ========== DICE FUNCTION ================= | |
| ############################################ | |
def dice_coefficient(pred, target, epsilon=1e-7):
    """Return the Dice coefficient between two masks as a Python float.

    Both inputs are binarised at 0.5 before comparison. Thresholding the
    target as well keeps the score a true set-overlap measure when the
    ground-truth mask contains intermediate values (e.g. after an
    interpolating resize) or is stored on a 0/255 scale.

    Args:
        pred: Tensor of predicted probabilities (or an already binary mask).
        target: Tensor of ground-truth values; must be broadcastable against
            ``pred``.
        epsilon: Small constant added to numerator and denominator so that
            two empty masks yield 1.0 instead of a division by zero.

    Returns:
        ``(2 * |P ∩ T| + epsilon) / (|P| + |T| + epsilon)`` as a float.
    """
    pred_bin = (pred > 0.5).float()
    target_bin = (target > 0.5).float()
    intersection = (pred_bin * target_bin).sum()
    total = pred_bin.sum() + target_bin.sum()
    return ((2.0 * intersection + epsilon) / (total + epsilon)).item()
| ############################################ | |
| # ========== TIFF SAFE LOADER ============== | |
| ############################################ | |
def load_image(file):
    """Load an uploaded image and return it as RGB plus the decoded array.

    Args:
        file: The upload handed over by the UI. Depending on the Gradio
            version/configuration this may be an object exposing the
            temporary file path as ``.name`` or a plain path string; both
            are accepted.

    Returns:
        A tuple ``(img_pil, img_np)`` where ``img_pil`` is the image converted
        to 8-bit RGB and ``img_np`` is the decoded pixel array (reduced to
        8-bit when the source was 16-bit) before the RGB conversion.
    """
    # Fall back to the value itself when there is no ``.name`` attribute
    # (i.e. the upload is already a path).
    path = getattr(file, "name", file)

    # np.array() copies the pixel data, so the file handle can be closed
    # right after decoding.
    with Image.open(path) as img:
        img_np = np.array(img)

    # Handle 16-bit TIFF: scale 0..65535 down to 0..255.
    if img_np.dtype == np.uint16:
        img_np = (img_np / 256).astype(np.uint8)

    img_pil = Image.fromarray(img_np).convert("RGB")
    return img_pil, img_np
| ############################################ | |
| # ========== PREDICTION ==================== | |
| ############################################ | |
def predict(image_file, mask_file=None):
    """Segment an uploaded image and return an overlay plus a status message.

    Args:
        image_file: Uploaded image as delivered by the UI, or ``None`` when
            nothing was uploaded.
        mask_file: Optional uploaded ground-truth mask. When given, the Dice
            score between the prediction and the mask is reported.

    Returns:
        ``(overlay, message)``. ``overlay`` is an RGB ``uint8`` array at the
        input image's resolution with predicted pixels painted red, or
        ``None`` when no image was supplied. ``message`` is the text shown
        in the result box.
    """
    if image_file is None:
        return None, "Please upload an image."

    image_pil, _ = load_image(image_file)
    input_tensor = transform(image_pil).unsqueeze(0)

    with torch.no_grad():
        probabilities = torch.sigmoid(model(input_tensor))

    pred_mask = probabilities.squeeze().numpy()
    pred_binary = (pred_mask > 0.5).astype(np.uint8)

    # Paint on the RGB version of the image: the raw decoded array may be
    # grayscale (2-D) or carry an alpha channel, in which case assigning an
    # [R, G, B] triple to the masked pixels would raise.
    overlay = np.array(image_pil)
    height, width = overlay.shape[:2]

    # Resize mask back to original size (cv2 expects (width, height)).
    pred_resized = cv2.resize(pred_binary, (width, height))
    overlay[pred_resized == 1] = [255, 0, 0]

    if mask_file is not None:
        mask_pil, _ = load_image(mask_file)
        # The resize inside ``transform`` can introduce intermediate values
        # along mask edges; re-binarise so Dice compares two binary masks.
        mask_tensor = (transform(mask_pil.convert("L")) > 0.5).float()
        dice = dice_coefficient(torch.tensor(pred_mask), mask_tensor)
        return overlay, f"Dice Score: {round(dice,4)}"

    return overlay, "Prediction complete."
| ############################################ | |
| # ========== GRADIO UI ===================== | |
| ############################################ | |
# Markdown text shown under the title of the web UI.
description = """
# 🧠 Brain Tumor Segmentation (UNet)
Dataset used for training:
https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation
Upload a `.tif` MRI image.
Optionally upload the ground-truth mask to compute Dice score.
"""

# Input widgets, in the order of predict()'s parameters.
_image_input = gr.File(
    file_types=[".tif", ".tiff", ".png", ".jpg"],
    label="Upload MRI Image",
)
_mask_input = gr.File(
    file_types=[".tif", ".tiff"],
    label="Optional Ground Truth Mask",
)

# Output widgets, in the order of predict()'s return values.
_overlay_output = gr.Image(label="Predicted Overlay")
_result_output = gr.Textbox(label="Result")

demo = gr.Interface(
    fn=predict,
    inputs=[_image_input, _mask_input],
    outputs=[_overlay_output, _result_output],
    title="UNet Brain Tumor Segmentation",
    description=description,
)

# Start the web server only when executed as a script.
if __name__ == "__main__":
    demo.launch()