# Hugging Face Spaces demo (Space status at capture time: Sleeping)
| import sys | |
| from pathlib import Path | |
| ROOT = Path(__file__).resolve().parent.parent | |
| sys.path.append(str(ROOT)) | |
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| from torchvision import transforms | |
| from model.unet import UNet | |
| from PIL import Image | |
| import matplotlib.pyplot as plt | |
| import io | |
# ---- Model setup ----
# Prefer the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Restore the U-Net weights saved after the 20th training epoch and
# switch the network to inference mode.
model_path = "model/unet_epoch20.pth"
state = torch.load(model_path, map_location=device)
model = UNet(in_channels=1, out_channels=3).to(device)
model.load_state_dict(state)
model.eval()
# Preprocessing step
def preprocess_image(img, size=(128, 128), target_device=None):
    """Convert a PIL-style image to a normalized single-channel batch tensor.

    Args:
        img: PIL.Image (any mode); converted to 8-bit grayscale ("L").
        size: (width, height) to resize to. Defaults to the 128x128
            resolution the model was trained on.
        target_device: device for the returned tensor. Defaults to the
            module-level ``device`` (GPU when available).

    Returns:
        torch.Tensor of shape (1, 1, H, W), dtype float32, values in [0, 1].
    """
    if target_device is None:
        target_device = device  # module-level default
    gray = img.convert("L").resize(size)
    array = np.array(gray).astype(np.float32) / 255.0
    # Add batch and channel dimensions: (H, W) -> (1, 1, H, W).
    tensor = torch.from_numpy(array).unsqueeze(0).unsqueeze(0).to(target_device)
    return tensor
# Prediction step
def segment(image):
    """Segment a brain MRI slice into background / gray matter / white matter.

    Args:
        image: PIL image of a T1-weighted brain slice (any mode/size);
            it is resized to 128x128 internally.

    Returns:
        tuple:
            - PIL.Image: the grayscale slice with a color overlay
              (red = gray matter, green = white matter).
            - dict: tissue name -> fraction of pixels, rounded to 2 decimals
              (``gr.Label`` renders 0-1 floats as percentages).
    """
    class_labels = {
        0: "Background",
        1: "Gray Matter",
        2: "White Matter",
    }
    # RGB overlay colors per class; class 0 stays black (invisible after blend).
    colors = {
        1: [255, 0, 0],  # Red for gray matter
        2: [0, 255, 0],  # Green for white matter
    }

    # Forward pass (no gradients needed at inference time).
    input_tensor = preprocess_image(image)
    with torch.no_grad():
        output = model(input_tensor)
    # (1, 3, H, W) -> (3, H, W) -> per-pixel class index (H, W).
    pred = torch.argmax(output.squeeze(), dim=0).cpu().numpy()

    # Per-class pixel fractions for the gr.Label output.
    total_pixels = pred.size
    stats = {
        class_labels[i]: round((pred == i).sum() / total_pixels, 2)
        for i in range(3)
    }

    # Build the overlay: grayscale base replicated to RGB, blended 50/50
    # with the color-coded class mask.
    base_img = np.array(image.convert("L").resize((128, 128)))
    overlay = np.zeros((128, 128, 3), dtype=np.uint8)
    for cls, color in colors.items():
        overlay[pred == cls] = color
    alpha = 0.5  # transparency of the color mask
    blended = (1 - alpha) * np.stack([base_img] * 3, axis=-1) + alpha * overlay
    output_image = Image.fromarray(blended.astype(np.uint8))

    return output_image, stats
# ---- Gradio UI ----
# One image in, an annotated image plus per-tissue percentages out.
_mri_input = gr.Image(type="pil", label="Upload Brain MRI (.jpg, .png)")
_outputs = [
    gr.Image(label="Segmented Output"),
    gr.Label(label="Tissue Composition (%)"),
]

demo = gr.Interface(
    fn=segment,
    inputs=_mri_input,
    outputs=_outputs,
    title="Brain MRI Segmentation (U-Net)",
    description="Upload a T1-weighted brain slice. The model will segment it into tissue classes and estimate area percentages.",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()