# Source: Hugging Face Space by ma4389 — app.py (commit 8b3c9aa, verified)
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import gradio as gr
import numpy as np
from PIL import Image
import cv2
############################################
# ========== ORIGINAL TRAINING UNET =========
############################################
class DoubleConv(nn.Module):
    """Two 3x3 same-padding convolutions, each followed by an in-place ReLU."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = [
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        ]
        # Attribute name must stay `conv_op` — checkpoint state_dict keys depend on it.
        self.conv_op = nn.Sequential(*stages)

    def forward(self, x):
        """Run both conv+ReLU stages; spatial size is preserved."""
        return self.conv_op(x)
class DownSample(nn.Module):
    """Encoder stage: DoubleConv, then 2x2 max-pool.

    forward() returns a pair: the pre-pool feature map (kept for the
    decoder's skip connection) and the pooled map (fed to the next stage).
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Attribute names `conv`/`pool` are fixed by the checkpoint's state_dict keys.
        self.conv = DoubleConv(in_channels, out_channels)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        features = self.conv(x)
        pooled = self.pool(features)
        return features, pooled
class UpSample(nn.Module):
    """Decoder stage: transpose-conv upsample, concat the skip, DoubleConv."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Transpose conv doubles spatial size and halves the channel count,
        # so the concat below restores `in_channels` total channels.
        self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
        self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        # NOTE(review): assumes up(x1) matches x2's spatial size exactly —
        # true here because inputs are resized to 256x256 (divisible by 16).
        upsampled = self.up(x1)
        merged = torch.cat([upsampled, x2], dim=1)
        return self.conv(merged)
class UNet(nn.Module):
    """Four-level UNet: 64/128/256/512 encoder channels, 1024-channel bottleneck.

    Emits raw logits with `num_classes` channels at the input resolution;
    callers apply sigmoid/softmax themselves.
    """

    def __init__(self, in_channels=3, num_classes=1):
        super().__init__()
        # All attribute names below are fixed by the checkpoint's state_dict keys.
        self.down_convolution_1 = DownSample(in_channels, 64)
        self.down_convolution_2 = DownSample(64, 128)
        self.down_convolution_3 = DownSample(128, 256)
        self.down_convolution_4 = DownSample(256, 512)
        self.bottle_neck = DoubleConv(512, 1024)
        self.up_convolution_1 = UpSample(1024, 512)
        self.up_convolution_2 = UpSample(512, 256)
        self.up_convolution_3 = UpSample(256, 128)
        self.up_convolution_4 = UpSample(128, 64)
        # 1x1 conv maps the final 64 feature channels to per-class logits.
        self.out = nn.Conv2d(in_channels=64, out_channels=num_classes, kernel_size=1)

    def forward(self, x):
        # Encoder: keep each pre-pool map as a skip connection.
        skip1, x = self.down_convolution_1(x)
        skip2, x = self.down_convolution_2(x)
        skip3, x = self.down_convolution_3(x)
        skip4, x = self.down_convolution_4(x)

        x = self.bottle_neck(x)

        # Decoder: consume the skips in reverse order.
        x = self.up_convolution_1(x, skip4)
        x = self.up_convolution_2(x, skip3)
        x = self.up_convolution_3(x, skip2)
        x = self.up_convolution_4(x, skip1)
        return self.out(x)
############################################
# ========== LOAD MODEL ====================
############################################
# Inference is CPU-only; the checkpoint file must sit next to this script.
device = torch.device("cpu")
model = UNet(in_channels=3, num_classes=1)
# weights_only=True restricts unpickling to tensors/primitive containers,
# preventing arbitrary code execution from a tampered checkpoint
# (torch.load's legacy default trusts the full pickle stream).
model.load_state_dict(
    torch.load("my_checkpoint.pth", map_location=device, weights_only=True)
)
model.eval()  # inference mode: freezes any dropout/batch-norm behavior
############################################
# ========== TRANSFORM =====================
############################################
# Preprocessing pipeline matching the training setup.
_preprocess_steps = [
    transforms.Resize((256, 256)),  # model expects 256x256 inputs
    transforms.ToTensor(),          # HWC uint8 -> CHW float in [0, 1]
]
transform = transforms.Compose(_preprocess_steps)
############################################
# ========== DICE FUNCTION =================
############################################
def dice_coefficient(pred, target, epsilon=1e-7):
    """Dice score between a probability map and a (0/1-valued) target.

    *pred* is thresholded at 0.5 before comparison; *target* is used as-is.
    *epsilon* keeps the ratio defined when both masks are empty.
    Returns a plain Python float.
    """
    binary = torch.gt(pred, 0.5).float()
    overlap = torch.sum(binary * target)
    total = torch.sum(binary) + torch.sum(target)
    score = (2.0 * overlap + epsilon) / (total + epsilon)
    return score.item()
############################################
# ========== TIFF SAFE LOADER ==============
############################################
def load_image(file):
    """Open an uploaded image and return (PIL RGB image, H x W x 3 uint8 array).

    Accepts either a filepath string or a file-like object exposing `.name`
    (what `gr.File` passes). 16-bit TIFFs are rescaled to 8-bit first,
    because PIL's RGB conversion expects uint8 data.

    Bug fix: the original returned the RAW numpy array, which is 2-D for
    single-channel TIFFs — downstream `overlay[mask] = [255, 0, 0]` then
    fails. We now return the array of the RGB-converted image, which is
    always 3-channel.
    """
    path = file if isinstance(file, str) else file.name
    img = Image.open(path)
    img_np = np.array(img)
    if img_np.dtype == np.uint16:
        # Map the 16-bit range onto 8 bits (integer divide by 256).
        img_np = (img_np // 256).astype(np.uint8)
        img = Image.fromarray(img_np)
    img_pil = img.convert("RGB")
    return img_pil, np.array(img_pil)
############################################
# ========== PREDICTION ====================
############################################
def predict(image_file, mask_file=None):
    """Segment an uploaded MRI slice with the global UNet.

    Returns (overlay RGB array, status string). When a ground-truth mask is
    supplied, the status string reports the Dice score instead.
    """
    if image_file is None:
        return None, "Please upload an image."

    image_pil, original_np = load_image(image_file)
    input_tensor = transform(image_pil).unsqueeze(0)

    with torch.no_grad():
        output = torch.sigmoid(model(input_tensor))

    pred_mask = output.squeeze().numpy()
    pred_binary = (pred_mask > 0.5).astype(np.uint8)

    # Resize back to the upload's resolution. INTER_NEAREST keeps the mask
    # strictly 0/1 — the default bilinear filter would smear fractional
    # values along the tumor boundary.
    pred_resized = cv2.resize(
        pred_binary,
        (original_np.shape[1], original_np.shape[0]),
        interpolation=cv2.INTER_NEAREST,
    )

    # Build the overlay from the RGB-converted image so indexing with
    # [255, 0, 0] is valid even when the uploaded file was single-channel.
    overlay = np.array(image_pil)
    overlay[pred_resized == 1] = [255, 0, 0]

    if mask_file is not None:
        mask_pil, _ = load_image(mask_file)
        # Binarize: ToTensor yields floats in [0, 1], and resizing (plus
        # masks stored as 0/255) can leave intermediate grey values.
        mask_tensor = (transform(mask_pil.convert("L")) > 0.5).float()
        dice = dice_coefficient(torch.tensor(pred_mask), mask_tensor)
        return overlay, f"Dice Score: {round(dice,4)}"

    return overlay, "Prediction complete."
############################################
# ========== GRADIO UI =====================
############################################
# Markdown rendered by Gradio above the interface controls.
description = """
# 🧠 Brain Tumor Segmentation (UNet)
Dataset used for training:
https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation
Upload a `.tif` MRI image.
Optionally upload the ground-truth mask to compute Dice score.
"""
# Gradio wiring: predict(image_file, mask_file) -> (overlay image, status text).
_inputs = [
    gr.File(file_types=[".tif", ".tiff", ".png", ".jpg"], label="Upload MRI Image"),
    gr.File(file_types=[".tif", ".tiff"], label="Optional Ground Truth Mask"),
]
_outputs = [
    gr.Image(label="Predicted Overlay"),
    gr.Textbox(label="Result"),
]

demo = gr.Interface(
    fn=predict,
    inputs=_inputs,
    outputs=_outputs,
    title="UNet Brain Tumor Segmentation",
    description=description,
)

if __name__ == "__main__":
    demo.launch()