# Brain_Unet_App / app.py
# Author: AndaiMD — commit f883a54 ("READme")
import sys
from pathlib import Path
# Put the directory one level above this file's folder on the import path,
# presumably so the `model.unet` import below resolves — TODO confirm layout.
_APP_FILE = Path(__file__).resolve()
ROOT = _APP_FILE.parent.parent
sys.path.append(str(ROOT))
import gradio as gr
import torch
import numpy as np
from torchvision import transforms
from model.unet import UNet
from PIL import Image
import matplotlib.pyplot as plt
import io
# Load the U-Net once at import time; every Gradio request reuses it.
# The checkpoint is the one saved after the 20th training epoch.
model_path = "model/unet_epoch20.pth"  # resolved against the current working directory
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 1 input channel (grayscale slice) -> 3 output channels, one per segmentation class.
model = UNet(in_channels=1, out_channels=3).to(device)
# The file only has to hold a tensor state_dict, so weights_only=True refuses to
# unpickle arbitrary objects (plain torch.load runs pickle) and avoids the
# FutureWarning that recent torch versions emit when the flag is omitted.
model.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
model.eval()  # inference behaviour for layers such as dropout/batch-norm, if present
# Preprocessing step
def preprocess_image(img):
    """Convert a PIL image into the model's input tensor.

    The image is converted to 8-bit grayscale, resized to 128x128, scaled
    from [0, 255] to [0.0, 1.0] as float32, given batch and channel axes
    (shape (1, 1, 128, 128)), and moved to the module-level `device`.
    """
    pixels = np.array(img.convert("L").resize((128, 128)), dtype=np.float32)
    pixels /= 255.0
    return torch.from_numpy(pixels)[None, None].to(device)
def _class_fractions(pred, class_labels):
    """Return {label: fraction of pixels in `pred` equal to that class id}.

    Fractions are plain Python floats in [0, 1], rounded to 2 decimals, in the
    iteration order of `class_labels`.
    """
    total = pred.size
    return {
        label: round(float((pred == cls).sum()) / total, 2)
        for cls, label in class_labels.items()
    }


def _blend_overlay(base_img, pred, colors, alpha=0.5):
    """Alpha-blend per-class RGB colors over a 2-D uint8 grayscale image.

    `pred` must have the same shape as `base_img`. Pixels whose class has no
    entry in `colors` are blended with black, so they are darkened by a
    factor of (1 - alpha) rather than left untouched.
    """
    overlay = np.zeros(base_img.shape + (3,), dtype=np.uint8)
    for cls, color in colors.items():
        overlay[pred == cls] = color
    rgb = np.stack([base_img] * 3, axis=-1)
    blended = (1 - alpha) * rgb + alpha * overlay
    return blended.astype(np.uint8)


# Prediction step
def segment(image):
    """Segment a brain MRI slice and summarize its tissue composition.

    Args:
        image: PIL image from the Gradio upload widget, or None when the user
            submits without uploading anything.

    Returns:
        (output_image, stats): a 128x128 RGB PIL image with gray matter tinted
        red and white matter green over the grayscale input, and a dict
        mapping each class name to its fraction of pixels (in [0, 1],
        rounded to 2 decimals).

    Raises:
        gr.Error: if no image was provided.
    """
    if image is None:
        # Gradio calls fn with None when Submit is pressed on an empty input;
        # gr.Error surfaces this message in the UI instead of a bare traceback.
        raise gr.Error("Please upload a brain MRI image first.")

    class_labels = {0: "Background", 1: "Gray Matter", 2: "White Matter"}
    # Class 0 (background) gets no color; see _blend_overlay for its effect.
    colors = {1: [255, 0, 0], 2: [0, 255, 0]}  # red: gray matter, green: white matter

    input_tensor = preprocess_image(image)
    with torch.no_grad():
        output = model(input_tensor)
    # squeeze() drops the size-1 batch axis so that dim 0 indexes the 3 class
    # channels; argmax over it gives one class id per pixel.
    pred = torch.argmax(output.squeeze(), dim=0).cpu().numpy()

    stats = _class_fractions(pred, class_labels)

    # Same grayscale + resize as preprocess_image, kept as uint8 for display.
    base_img = np.array(image.convert("L").resize((128, 128)))
    output_image = Image.fromarray(_blend_overlay(base_img, pred, colors, alpha=0.5))
    return output_image, stats
# --- Gradio UI ---------------------------------------------------------------
# One PIL image in; the blended overlay image plus per-class fractions out.
_mri_upload = gr.Image(type="pil", label="Upload Brain MRI (.jpg, .png)")
_overlay_view = gr.Image(label="Segmented Output")
_composition_view = gr.Label(label="Tissue Composition (%)")

demo = gr.Interface(
    fn=segment,
    inputs=_mri_upload,
    outputs=[_overlay_view, _composition_view],
    title="Brain MRI Segmentation (U-Net)",
    description=(
        "Upload a T1-weighted brain slice. The model will segment it into "
        "tissue classes and estimate area percentages."
    ),
    # NOTE(review): Gradio 5 deprecates allow_flagging in favour of
    # flagging_mode; kept as-is because the pinned Gradio version isn't visible.
    allow_flagging="never",
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()