FibrilSegNet / app.py
himanshuch8055's picture
Remove mean and standard confidence output from prediction results; update predicted mask image
671374a
# # +++++++++++++++ First Version: 1.0.0 +++++++++++++++++++++
# # Last Updated: 08 July 2025
# # Fibril Segmentation with UNet++ using Gradio
# import os
# import torch
# import numpy as np
# from PIL import Image
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
# import segmentation_models_pytorch as smp
# import gradio as gr
# # ─── Configuration ─────────────────────────────────────────
# CONFIG = {
# "model_path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
# "img_size": 512
# }
# # ─── Device Setup ──────────────────────────────────────────
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"✅ Using device: {device}")
# # ─── Load Model ────────────────────────────────────────────
# model = smp.UnetPlusPlus(
# encoder_name='resnet34',
# encoder_depth=5,
# encoder_weights='imagenet',
# decoder_use_norm='batchnorm',
# decoder_channels=(256, 128, 64, 32, 16),
# decoder_attention_type=None,
# decoder_interpolation='nearest',
# in_channels=1,
# classes=1,
# activation=None
# ).to(device)
# model.load_state_dict(torch.load(CONFIG["model_path"], map_location=device))
# model.eval()
# # ─── Transform Function ────────────────────────────────────
# def get_transform(size):
# return A.Compose([
# A.Resize(size, size),
# A.Normalize(mean=(0.5,), std=(0.5,)),
# ToTensorV2()
# ])
# transform = get_transform(CONFIG["img_size"])
# # ─── Prediction Function ───────────────────────────────────
# def predict(image):
# image = image.convert("L") # Convert to grayscale
# img_np = np.array(image)
# img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
# with torch.no_grad():
# pred = torch.sigmoid(model(img_tensor))
# mask = (pred > 0.5).float().cpu().squeeze().numpy()
# mask_img = Image.fromarray((mask * 255).astype(np.uint8))
# return mask_img
# # ─── Gradio Interface ──────────────────────────────────────
# demo = gr.Interface(
# fn=predict,
# inputs=gr.Image(type="pil", label="Upload Microscopy Image"),
# outputs=gr.Image(type="pil", label="Predicted Segmentation Mask"),
# title="Fibril Segmentation with Unet++",
# description="Upload a grayscale microscopy image to get its predicted segmentation mask.",
# allow_flagging="never",
# live=False
# )
# if __name__ == "__main__":
# demo.launch()
# # +++++++++++++++ Second Version: 1.1.0 ++++++++++++++++++++
# # Last Updated: 08 July 2025
# # Improvements: Added examples, better UI, and device handling
# import os
# import torch
# import numpy as np
# from PIL import Image
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
# import segmentation_models_pytorch as smp
# import gradio as gr
# # ─── Configuration ─────────────────────────────────────────
# CONFIG = {
# "model_path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
# "img_size": 512
# }
# # ─── Device Setup ──────────────────────────────────────────
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"✅ Using device: {device}")
# # ─── Load Model ────────────────────────────────────────────
# model = smp.UnetPlusPlus(
# encoder_name='resnet34',
# encoder_depth=5,
# encoder_weights='imagenet',
# decoder_channels=(256, 128, 64, 32, 16),
# in_channels=1,
# classes=1,
# activation=None
# ).to(device)
# model.load_state_dict(torch.load(CONFIG["model_path"], map_location=device))
# model.eval()
# # ─── Transform Function ────────────────────────────────────
# def get_transform(size):
# return A.Compose([
# A.Resize(size, size),
# A.Normalize(mean=(0.5,), std=(0.5,)),
# ToTensorV2()
# ])
# transform = get_transform(CONFIG["img_size"])
# # ─── Prediction Function ───────────────────────────────────
# def predict(image):
# image = image.convert("L") # Ensure grayscale
# img_np = np.array(image)
# img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
# with torch.no_grad():
# pred = torch.sigmoid(model(img_tensor))
# mask = (pred > 0.5).float().cpu().squeeze().numpy()
# mask_img = Image.fromarray((mask * 255).astype(np.uint8))
# return mask_img
# # ─── Gradio UI (Improved) ──────────────────────────────────
# examples = [
# ["examples/example1.jpg"],
# ["examples/example2.jpg"],
# ["examples/example3.jpg"],
# ["examples/example4.jpg"],
# ["examples/example5.jpg"],
# ["examples/example6.jpg"],
# ["examples/example7.jpg"]
# ]
# css = """
# .gradio-container {
# max-width: 950px;
# margin: auto;
# }
# .gr-button {
# background-color: #4a90e2;
# color: white;
# border-radius: 5px;
# }
# .gr-button:hover {
# background-color: #357ABD;
# }
# """
# with gr.Blocks(css=css) as demo:
# gr.Markdown("## 🧬 Fibril Segmentation with UNet++")
# gr.Markdown("Upload a **grayscale microscopy image**, and this model will predict the **segmentation mask of fibrillar structures**.\n\nModel: ResNet34 encoder + UNet++ decoder")
# with gr.Row():
# input_img = gr.Image(label="Upload Microscopy Image", type="pil")
# output_mask = gr.Image(label="Predicted Segmentation Mask", type="pil")
# submit_btn = gr.Button("Segment Image")
# submit_btn.click(fn=predict, inputs=input_img, outputs=output_mask)
# gr.Examples(
# examples=examples,
# inputs=input_img,
# label="Try with Example Images",
# cache_examples=False
# )
# # ─── Launch App ────────────────────────────────────────────
# if __name__ == "__main__":
# demo.launch()
# +++++++++++++++ Final Version: 1.2.0 ++++++++++++++++++++++
# Last Updated: 08 July 2025
# Improvements: Added model selection, better UI, and device handling
# import os
# import torch
# import numpy as np
# from PIL import Image
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
# import segmentation_models_pytorch as smp
# import gradio as gr
# # ─── Device Setup ──────────────────────────────────────────
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"✅ Using device: {device}")
# # ─── Model Configurations ──────────────────────────────────
# MODEL_OPTIONS = {
# "UNet++ (ResNet34)": {
# "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
# "encoder": "resnet34",
# "architecture": "UnetPlusPlus"
# },
# "UNet (ResNet34)": {
# "path": "./model/unet_fibril_seg_model.pth",
# "encoder": "resnet34",
# "architecture": "Unet"
# },
# "UNet++ (efficientnet-b3)": {
# "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
# "encoder": "efficientnet-b3",
# "architecture": "UnetPlusPlus"
# },
# "DeepLabV3Plus (efficientnet-b3)": {
# "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth",
# "encoder": "efficientnet-b3",
# "architecture": "DeepLabV3Plus"
# }
# }
# # ─── Transform Function ────────────────────────────────────
# def get_transform(size):
# return A.Compose([
# A.Resize(size, size),
# A.Normalize(mean=(0.5,), std=(0.5,)),
# ToTensorV2()
# ])
# transform = get_transform(512)
# # ─── Model Loader ──────────────────────────────────────────
# # def load_model(model_name):
# # config = MODEL_OPTIONS[model_name]
# # if config["architecture"] == "UnetPlusPlus":
# # model = smp.UnetPlusPlus(
# # encoder_name=config["encoder"],
# # encoder_weights="imagenet",
# # decoder_channels=(256, 128, 64, 32, 16),
# # in_channels=1,
# # classes=1,
# # activation=None
# # )
# # elif config["architecture"] == "Unet":
# # model = smp.Unet(
# # encoder_name=config["encoder"],
# # encoder_weights="imagenet",
# # decoder_channels=(256, 128, 64, 32, 16),
# # in_channels=1,
# # classes=1,
# # activation=None
# # )
# # else:
# # raise ValueError(f"Unsupported architecture: {config['architecture']}")
# # model.load_state_dict(torch.load(config["path"], map_location=device))
# # model.eval()
# # return model.to(device)
# model_cache = {}
# def load_model(model_name):
# if model_name in model_cache:
# return model_cache[model_name]
# config = MODEL_OPTIONS[model_name]
# if config["architecture"] == "UnetPlusPlus":
# model = smp.UnetPlusPlus(
# encoder_name=config["encoder"],
# encoder_weights="imagenet",
# decoder_channels=(256, 128, 64, 32, 16),
# in_channels=1,
# classes=1,
# activation=None
# )
# elif config["architecture"] == "Unet":
# model = smp.Unet(
# encoder_name=config["encoder"],
# encoder_weights="imagenet",
# decoder_channels=(256, 128, 64, 32, 16),
# in_channels=1,
# classes=1,
# activation=None
# )
# elif config["architecture"] == "DeepLabV3Plus":
# model = smp.DeepLabV3Plus(
# encoder_name=config["encoder"],
# encoder_weights="imagenet",
# # decoder_channels=(256, 128, 64, 32, 16), # Not used in DeepLabV3Plus
# in_channels=1,
# classes=1,
# activation=None
# )
# else:
# raise ValueError(f"Unsupported architecture: {config['architecture']}")
# model.load_state_dict(torch.load(config["path"], map_location=device))
# model.eval()
# model_cache[model_name] = model.to(device)
# return model_cache[model_name]
# # ─── Prediction Function ───────────────────────────────────
# def predict(image, model_name):
# image = image.convert("L")
# img_np = np.array(image)
# img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
# model = load_model(model_name)
# with torch.no_grad():
# pred = torch.sigmoid(model(img_tensor))
# mask = (pred > 0.5).float().cpu().squeeze().numpy()
# mask_img = Image.fromarray((mask * 255).astype(np.uint8))
# return mask_img
# # ─── Example Images ────────────────────────────────────────
# examples = [
# ["examples/example1.jpg"],
# ["examples/example2.jpg"],
# ["examples/example3.jpg"],
# ["examples/example4.jpg"],
# ["examples/example5.jpg"],
# ["examples/example6.jpg"],
# ["examples/example7.jpg"]
# ]
# # ─── Custom CSS ────────────────────────────────────────────
# css = """
# .gradio-container {
# max-width: 950px;
# margin: auto;
# }
# .gr-button {
# background-color: #4a90e2;
# color: white;
# border-radius: 5px;
# }
# .gr-button:hover {
# background-color: #357ABD;
# }
# """
# # ─── Gradio UI ─────────────────────────────────────────────
# with gr.Blocks(css=css) as demo:
# gr.Markdown("## 🧬 Fibril Segmentation Interface")
# gr.Markdown("Choose a model and upload a grayscale microscopy image. The model will predict the **fibrillar structure mask**.")
# with gr.Row():
# model_selector = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="UNet++ (ResNet34)", label="Select Model")
# with gr.Row():
# input_img = gr.Image(label="Upload Microscopy Image", type="pil")
# output_mask = gr.Image(label="Predicted Segmentation Mask", type="pil")
# submit_btn = gr.Button("Segment Image")
# submit_btn.click(fn=predict, inputs=[input_img, model_selector], outputs=output_mask)
# gr.Examples(
# examples=examples,
# inputs=input_img,
# label="Try with Example Images",
# cache_examples=False
# )
# # ─── Launch App ────────────────────────────────────────────
# if __name__ == "__main__":
# demo.launch()
# +++++++++++++++ Final Version: 1.3.0 ++++++++++++++++++++++
# Last Updated: 10 July 2025
# Improvements: Added adjustable segmentation threshold, confidence-map output, status message, and downloadable mask
# import os
# import torch
# import numpy as np
# from PIL import Image
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
# import segmentation_models_pytorch as smp
# import gradio as gr
# # ─── Device Setup ──────────────────────────────────────────
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"✅ Using device: {device}")
# # ─── Model Configurations ──────────────────────────────────
# MODEL_OPTIONS = {
# "UNet++ (ResNet34)": {
# "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
# "encoder": "resnet34",
# "architecture": "UnetPlusPlus",
# "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy."
# },
# "UNet (ResNet34)": {
# "path": "./model/unet_fibril_seg_model.pth",
# "encoder": "resnet34",
# "architecture": "Unet",
# "description": "Classic UNet with ResNet34 encoder — fast and lightweight."
# },
# "UNet++ (efficientnet-b3)": {
# "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
# "encoder": "efficientnet-b3",
# "architecture": "UnetPlusPlus",
# "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU."
# },
# "DeepLabV3Plus (efficientnet-b3)": {
# "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth",
# "encoder": "efficientnet-b3",
# "architecture": "DeepLabV3Plus",
# "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU."
# }
# }
# # ─── Transform Function ────────────────────────────────────
# def get_transform(size):
# return A.Compose([
# A.Resize(size, size),
# A.Normalize(mean=(0.5,), std=(0.5,)),
# ToTensorV2()
# ])
# transform = get_transform(512)
# # ─── Model Loader ──────────────────────────────────────────
# model_cache = {}
# def load_model(model_name):
# if model_name in model_cache:
# return model_cache[model_name]
# config = MODEL_OPTIONS[model_name]
# if config["architecture"] == "UnetPlusPlus":
# model = smp.UnetPlusPlus(
# encoder_name=config["encoder"],
# encoder_weights="imagenet",
# decoder_channels=(256, 128, 64, 32, 16),
# in_channels=1,
# classes=1,
# activation=None
# )
# elif config["architecture"] == "Unet":
# model = smp.Unet(
# encoder_name=config["encoder"],
# encoder_weights="imagenet",
# decoder_channels=(256, 128, 64, 32, 16),
# in_channels=1,
# classes=1,
# activation=None
# )
# elif config["architecture"] == "DeepLabV3Plus":
# model = smp.DeepLabV3Plus(
# encoder_name=config["encoder"],
# encoder_weights="imagenet",
# in_channels=1,
# classes=1,
# activation=None
# )
# else:
# raise ValueError(f"Unsupported architecture: {config['architecture']}")
# model.load_state_dict(torch.load(config["path"], map_location=device))
# model.eval()
# model_cache[model_name] = model.to(device)
# return model_cache[model_name]
# # ─── Prediction Function ───────────────────────────────────
# @torch.no_grad()
# def predict(image, model_name, threshold):
# if image is None:
# return "❌ Please upload an image.", None, None, None
# image = image.convert("L")
# img_np = np.array(image)
# img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
# model = load_model(model_name)
# pred = torch.sigmoid(model(img_tensor))
# prob_map = pred.cpu().squeeze().numpy()
# mask = (prob_map > threshold).astype(np.float32)
# mask_img = Image.fromarray((mask * 255).astype(np.uint8))
# prob_img = Image.fromarray((prob_map * 255).astype(np.uint8))
# # Save mask temporarily for download
# out_path = "predicted_mask.png"
# mask_img.save(out_path)
# return "✅ Done!", mask_img, prob_img, out_path
# # ─── UI and Layout ─────────────────────────────────────────
# css = """
# .gradio-container {
# max-width: 950px;
# margin: auto;
# }
# .gr-button {
# background-color: #4a90e2;
# color: white;
# border-radius: 5px;
# }
# .gr-button:hover {
# background-color: #357ABD;
# }
# """
# with gr.Blocks(css=css) as demo:
# gr.Markdown("## 🧬 Fibril Segmentation Interface")
# gr.Markdown("Upload a grayscale microscopy image, choose a model, adjust threshold, and get a fibrillar structure mask.")
# with gr.Row():
# model_selector = gr.Dropdown(
# choices=list(MODEL_OPTIONS.keys()),
# value="UNet++ (ResNet34)",
# label="Select Model"
# )
# threshold_slider = gr.Slider(
# minimum=0.0, maximum=1.0, value=0.5, step=0.01,
# label="Segmentation Threshold"
# )
# model_info = gr.Textbox(
# label="Model Description",
# lines=3,
# interactive=False,
# value=MODEL_OPTIONS[model_selector.value]["description"]
# )
# def update_model_info(model_name):
# return MODEL_OPTIONS[model_name]["description"]
# model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info)
# with gr.Tabs():
# with gr.Tab("🖼️ Segmentation"):
# with gr.Row():
# input_img = gr.Image(
# label="Upload Microscopy Image",
# type="pil",
# )
# output_mask = gr.Image(label="Predicted Mask", type="pil")
# confidence_map = gr.Image(label="Confidence Map", type="pil")
# submit_btn = gr.Button("Segment Image")
# status_box = gr.Textbox(label="Status", interactive=False)
# output_download = gr.File(label="Download Predicted Mask")
# submit_btn.click(
# fn=predict,
# inputs=[input_img, model_selector, threshold_slider],
# outputs=[status_box, output_mask, confidence_map, output_download]
# )
# with gr.Tab("📁 Examples"):
# gr.Examples(
# examples=[
# ["examples/example1.jpg"],
# ["examples/example2.jpg"],
# ["examples/example3.jpg"],
# ["examples/example4.jpg"],
# ["examples/example5.jpg"],
# ["examples/example6.jpg"],
# ["examples/example7.jpg"]
# ],
# inputs=input_img,
# label="Try with Example Images",
# cache_examples=False
# )
# if __name__ == "__main__":
# demo.launch()
# +++++++++++++++ Final Version: 1.4.0 ++++++++++++++++++++++
# Last Updated: 10 July 2025
# Improvements: Reorganized UI into a two-column layout (examples/upload panel beside controls and outputs)
# import os
# import torch
# import numpy as np
# from PIL import Image
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
# import segmentation_models_pytorch as smp
# import gradio as gr
# # ─── Device Setup ──────────────────────────────────────────
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"✅ Using device: {device}")
# # ─── Model Configurations ──────────────────────────────────
# MODEL_OPTIONS = {
# "UNet++ (ResNet34)": {
# "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
# "encoder": "resnet34",
# "architecture": "UnetPlusPlus",
# "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy."
# },
# "UNet (ResNet34)": {
# "path": "./model/unet_fibril_seg_model.pth",
# "encoder": "resnet34",
# "architecture": "Unet",
# "description": "Classic UNet with ResNet34 encoder — fast and lightweight."
# },
# "UNet++ (efficientnet-b3)": {
# "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
# "encoder": "efficientnet-b3",
# "architecture": "UnetPlusPlus",
# "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU."
# },
# "DeepLabV3Plus (efficientnet-b3)": {
# "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth",
# "encoder": "efficientnet-b3",
# "architecture": "DeepLabV3Plus",
# "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU."
# }
# }
# # ─── Transform Function ────────────────────────────────────
# def get_transform(size):
# return A.Compose([
# A.Resize(size, size),
# A.Normalize(mean=(0.5,), std=(0.5,)),
# ToTensorV2()
# ])
# transform = get_transform(512)
# # ─── Model Loader ──────────────────────────────────────────
# model_cache = {}
# def load_model(model_name):
# if model_name in model_cache:
# return model_cache[model_name]
# config = MODEL_OPTIONS[model_name]
# if config["architecture"] == "UnetPlusPlus":
# model = smp.UnetPlusPlus(
# encoder_name=config["encoder"],
# encoder_weights="imagenet",
# decoder_channels=(256, 128, 64, 32, 16),
# in_channels=1,
# classes=1,
# activation=None
# )
# elif config["architecture"] == "Unet":
# model = smp.Unet(
# encoder_name=config["encoder"],
# encoder_weights="imagenet",
# decoder_channels=(256, 128, 64, 32, 16),
# in_channels=1,
# classes=1,
# activation=None
# )
# elif config["architecture"] == "DeepLabV3Plus":
# model = smp.DeepLabV3Plus(
# encoder_name=config["encoder"],
# encoder_weights="imagenet",
# in_channels=1,
# classes=1,
# activation=None
# )
# else:
# raise ValueError(f"Unsupported architecture: {config['architecture']}")
# model.load_state_dict(torch.load(config["path"], map_location=device))
# model.eval()
# model_cache[model_name] = model.to(device)
# return model_cache[model_name]
# # ─── Prediction Function ───────────────────────────────────
# @torch.no_grad()
# def predict(image, model_name, threshold):
# if image is None:
# return "❌ Please upload an image.", None, None, None
# image = image.convert("L")
# img_np = np.array(image)
# img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
# model = load_model(model_name)
# pred = torch.sigmoid(model(img_tensor))
# prob_map = pred.cpu().squeeze().numpy()
# mask = (prob_map > threshold).astype(np.float32)
# mask_img = Image.fromarray((mask * 255).astype(np.uint8))
# prob_img = Image.fromarray((prob_map * 255).astype(np.uint8))
# out_path = "predicted_mask.png"
# mask_img.save(out_path)
# return "✅ Done!", mask_img, prob_img, out_path
# # ─── UI and Layout ─────────────────────────────────────────
# css = """
# .gradio-container {
# max-width: 1100px;
# margin: auto;
# }
# .gr-button {
# background-color: #4a90e2;
# color: white;
# border-radius: 5px;
# }
# .gr-button:hover {
# background-color: #357ABD;
# }
# """
# with gr.Blocks(css=css) as demo:
# gr.Markdown("## 🧬 Fibril Segmentation Interface")
# gr.Markdown("Upload a grayscale microscopy image, choose a model, adjust threshold, and get a fibrillar structure mask.")
# with gr.Row():
# # Left Panel – Examples and Upload
# with gr.Column(scale=1):
# gr.Markdown("### 📁 Example Images")
# input_img = gr.Image(
# label="Upload Microscopy Image",
# type="pil",
# interactive=True
# )
# gr.Examples(
# examples=[
# ["examples/example1.jpg"],
# ["examples/example2.jpg"],
# ["examples/example3.jpg"],
# ["examples/example4.jpg"],
# ["examples/example5.jpg"],
# ["examples/example6.jpg"],
# ["examples/example7.jpg"]
# ],
# inputs=input_img,
# label="Try with Example Images",
# cache_examples=False
# )
# # Right Panel – Controls & Output
# with gr.Column(scale=2):
# with gr.Row():
# model_selector = gr.Dropdown(
# choices=list(MODEL_OPTIONS.keys()),
# value="UNet++ (ResNet34)",
# label="Select Model"
# )
# threshold_slider = gr.Slider(
# minimum=0.0, maximum=1.0, value=0.5, step=0.01,
# label="Segmentation Threshold"
# )
# model_info = gr.Textbox(
# label="Model Description",
# lines=3,
# interactive=False,
# value=MODEL_OPTIONS["UNet++ (ResNet34)"]["description"]
# )
# def update_model_info(model_name):
# return MODEL_OPTIONS[model_name]["description"]
# model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info)
# with gr.Row():
# output_mask = gr.Image(label="Predicted Mask", type="pil")
# confidence_map = gr.Image(label="Confidence Map", type="pil")
# submit_btn = gr.Button("Segment Image")
# status_box = gr.Textbox(label="Status", interactive=False)
# output_download = gr.File(label="Download Predicted Mask")
# submit_btn.click(
# fn=predict,
# inputs=[input_img, model_selector, threshold_slider],
# outputs=[status_box, output_mask, confidence_map, output_download]
# )
# # ─── Launch App ────────────────────────────────────────────
# if __name__ == "__main__":
# demo.launch()
# +++++++++++++++ Final Version: 1.5.0 ++++++++++++++++++++++
# Last Updated: 10 July 2025
# Improvements: Added Otsu auto-thresholding, mask post-processing (small-object removal, hole filling), overlay view, and area/confidence statistics
# import os
# import torch
# import numpy as np
# from PIL import Image
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
# import segmentation_models_pytorch as smp
# import gradio as gr
# from skimage import filters, measure, morphology
# # ─── Device Setup ─────────────────────────────
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"✅ Using device: {device}")
# # ─── Model Configurations ─────────────────────
# MODEL_OPTIONS = {
# "UNet++ (ResNet34)": {
# "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
# "encoder": "resnet34",
# "architecture": "UnetPlusPlus",
# "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy."
# },
# "UNet (ResNet34)": {
# "path": "./model/unet_fibril_seg_model.pth",
# "encoder": "resnet34",
# "architecture": "Unet",
# "description": "Classic UNet with ResNet34 encoder — fast and lightweight."
# },
# "UNet++ (efficientnet-b3)": {
# "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
# "encoder": "efficientnet-b3",
# "architecture": "UnetPlusPlus",
# "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU."
# },
# "DeepLabV3Plus (efficientnet-b3)": {
# "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth",
# "encoder": "efficientnet-b3",
# "architecture": "DeepLabV3Plus",
# "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU."
# }
# }
# def get_transform(size):
# return A.Compose([
# A.Resize(size, size),
# A.Normalize(mean=(0.5,), std=(0.5,)),
# ToTensorV2()
# ])
# transform = get_transform(512)
# model_cache = {}
# def load_model(model_name):
# if model_name in model_cache:
# return model_cache[model_name]
# config = MODEL_OPTIONS[model_name]
# if config["architecture"] == "UnetPlusPlus":
# model = smp.UnetPlusPlus(
# encoder_name=config["encoder"], encoder_weights="imagenet",
# decoder_channels=(256, 128, 64, 32, 16),
# in_channels=1, classes=1, activation=None)
# elif config["architecture"] == "Unet":
# model = smp.Unet(
# encoder_name=config["encoder"], encoder_weights="imagenet",
# decoder_channels=(256, 128, 64, 32, 16),
# in_channels=1, classes=1, activation=None)
# elif config["architecture"] == "DeepLabV3Plus":
# model = smp.DeepLabV3Plus(
# encoder_name=config["encoder"], encoder_weights="imagenet",
# in_channels=1, classes=1, activation=None)
# else:
# raise ValueError("Unsupported architecture.")
# model.load_state_dict(torch.load(config["path"], map_location=device))
# model.eval()
# model_cache[model_name] = model.to(device)
# return model_cache[model_name]
# # ─── Prediction Logic ─────────────────────────
# @torch.no_grad()
# def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, show_overlay, show_stats):
# if image is None:
# return "❌ Please upload an image.", None, None, None, "", ""
# image = image.convert("L")
# img_np = np.array(image)
# img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
# model = load_model(model_name)
# pred = torch.sigmoid(model(img_tensor)).cpu().squeeze().numpy()
# if use_otsu:
# threshold = filters.threshold_otsu(pred)
# binary_mask = (pred > threshold).astype(np.float32)
# # Post-process mask
# if remove_noise:
# binary_mask = morphology.remove_small_objects(binary_mask > 0, 64)
# if fill_holes:
# binary_mask = morphology.remove_small_holes(binary_mask > 0, 64)
# binary_mask = binary_mask.astype(np.float32)
# mask_img = Image.fromarray((binary_mask * 255).astype(np.uint8))
# prob_img = Image.fromarray((pred * 255).astype(np.uint8))
# if show_overlay:
# mask_resized = mask_img.resize(image.size, resample=Image.NEAREST).convert("RGB")
# overlay_img = Image.blend(image.convert("RGB"), mask_resized, alpha=0.4)
# else:
# overlay_img = None
# stats_text = ""
# if show_stats:
# labeled_mask = measure.label(binary_mask)
# area = np.sum(binary_mask)
# mean_conf = np.mean(pred[binary_mask > 0]) if area > 0 else 0
# std_conf = np.std(pred[binary_mask > 0]) if area > 0 else 0
# stats_text = (
# f"🧮 Stats:\n - Area (px): {area:.0f}\n"
# f" - Objects: {labeled_mask.max()}\n"
# f" - Mean Conf: {mean_conf:.3f}\n"
# f" - Std Conf: {std_conf:.3f}"
# )
# mask_img.save("predicted_mask.png")
# return "✅ Done!", mask_img, prob_img, overlay_img, "predicted_mask.png", stats_text
# # ─── UI ───────────────────────────────────────
# css = """
# .gradio-container { max-width: 1100px; margin: auto; }
# .gr-button { background-color: #4a90e2; color: white; border-radius: 5px; }
# .gr-button:hover { background-color: #357ABD; }
# """
# with gr.Blocks(css=css) as demo:
# gr.Markdown("## 🧬 Fibril Segmentation Interface")
# gr.Markdown("Upload a grayscale microscopy image and choose options to segment fibrillar structures.")
# with gr.Row():
# with gr.Column(scale=1):
# input_img = gr.Image(label="Upload Image", type="pil", interactive=True)
# gr.Examples(
# examples=[[f"examples/example{i}.jpg"] for i in range(1, 8)],
# inputs=input_img,
# label="📁 Try Example Images", cache_examples=False
# )
# with gr.Column(scale=2):
# model_selector = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="UNet++ (ResNet34)", label="Model")
# threshold_slider = gr.Slider(0, 1, value=0.5, step=0.01, label="Segmentation Threshold")
# use_otsu = gr.Checkbox(label="Use Otsu Threshold", value=False)
# remove_noise = gr.Checkbox(label="Remove Small Objects", value=True)
# fill_holes = gr.Checkbox(label="Fill Holes in Mask", value=True)
# show_overlay = gr.Checkbox(label="Show Overlay on Original", value=True)
# show_stats = gr.Checkbox(label="Show Area & Confidence Stats", value=True)
# model_info = gr.Textbox(label="Model Description", interactive=False, lines=2)
# def update_model_info(name): return MODEL_OPTIONS[name]["description"]
# model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info)
# submit = gr.Button("Segment Image")
# status = gr.Textbox(label="Status", interactive=False)
# with gr.Row():
# mask_output = gr.Image(label="Binary Mask")
# prob_output = gr.Image(label="Confidence Map")
# overlay_output = gr.Image(label="Overlay")
# stats_output = gr.Textbox(label="Segmentation Stats", lines=4)
# file_output = gr.File(label="Download Mask")
# submit.click(
# fn=predict,
# inputs=[
# input_img, model_selector, threshold_slider, use_otsu,
# remove_noise, fill_holes, show_overlay, show_stats
# ],
# outputs=[
# status, mask_output, prob_output, overlay_output,
# file_output, stats_output
# ]
# )
# if __name__ == "__main__":
# demo.launch()
# +++++++++++++++ Final Version: 1.6.0 ++++++++++++++++++++++
# Last Updated: 10 July 2025
# Improvements: Added model selection, better UI, and device handling
# import os
# import torch
# import numpy as np
# from PIL import Image
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
# import segmentation_models_pytorch as smp
# import gradio as gr
# from skimage import filters, measure, morphology
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# MODEL_OPTIONS = {
# "UNet++ (ResNet34)": {
# "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
# "encoder": "resnet34",
# "architecture": "UnetPlusPlus",
# "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy."
# },
# "UNet (ResNet34)": {
# "path": "./model/unet_fibril_seg_model.pth",
# "encoder": "resnet34",
# "architecture": "Unet",
# "description": "Classic UNet with ResNet34 encoder — fast and lightweight."
# },
# "UNet++ (efficientnet-b3)": {
# "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
# "encoder": "efficientnet-b3",
# "architecture": "UnetPlusPlus",
# "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU."
# },
# "DeepLabV3Plus (efficientnet-b3)": {
# "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth",
# "encoder": "efficientnet-b3",
# "architecture": "DeepLabV3Plus",
# "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU."
# }
# }
# def get_transform(size):
# return A.Compose([
# A.Resize(size, size),
# A.Normalize(mean=(0.5,), std=(0.5,)),
# ToTensorV2()
# ])
# transform = get_transform(512)
# model_cache = {}
# def load_model(model_name):
# if model_name in model_cache:
# return model_cache[model_name]
# config = MODEL_OPTIONS[model_name]
# if config["architecture"] == "UnetPlusPlus":
# model = smp.UnetPlusPlus(
# encoder_name=config["encoder"], encoder_weights="imagenet",
# decoder_channels=(256, 128, 64, 32, 16),
# in_channels=1, classes=1, activation=None)
# elif config["architecture"] == "Unet":
# model = smp.Unet(
# encoder_name=config["encoder"], encoder_weights="imagenet",
# decoder_channels=(256, 128, 64, 32, 16),
# in_channels=1, classes=1, activation=None)
# elif config["architecture"] == "DeepLabV3Plus":
# model = smp.DeepLabV3Plus(
# encoder_name=config["encoder"], encoder_weights="imagenet",
# in_channels=1, classes=1, activation=None)
# else:
# raise ValueError("Unsupported architecture.")
# model.load_state_dict(torch.load(config["path"], map_location=device))
# model.eval()
# model_cache[model_name] = model.to(device)
# return model_cache[model_name]
# @torch.no_grad()
# def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, show_overlay, show_stats):
# if image is None:
# return "❌ Please upload an image.", None, None, None, "", ""
# image = image.convert("L")
# img_np = np.array(image)
# img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
# model = load_model(model_name)
# pred = torch.sigmoid(model(img_tensor)).cpu().squeeze().numpy()
# if use_otsu:
# threshold = filters.threshold_otsu(pred)
# binary_mask = (pred > threshold).astype(np.float32)
# if remove_noise:
# binary_mask = morphology.remove_small_objects(binary_mask > 0, 64)
# if fill_holes:
# binary_mask = morphology.remove_small_holes(binary_mask > 0, 64)
# binary_mask = binary_mask.astype(np.float32)
# mask_img = Image.fromarray((binary_mask * 255).astype(np.uint8))
# prob_img = Image.fromarray((pred * 255).astype(np.uint8))
# if show_overlay:
# mask_resized = mask_img.resize(image.size, resample=Image.NEAREST).convert("RGB")
# overlay_img = Image.blend(image.convert("RGB"), mask_resized, alpha=0.4)
# else:
# overlay_img = None
# stats_text = ""
# if show_stats:
# labeled_mask = measure.label(binary_mask)
# area = np.sum(binary_mask)
# mean_conf = np.mean(pred[binary_mask > 0]) if area > 0 else 0
# std_conf = np.std(pred[binary_mask > 0]) if area > 0 else 0
# stats_text = (
# f"🧮 Stats:\n - Area (px): {area:.0f}\n"
# f" - Objects: {labeled_mask.max()}\n"
# f" - Mean Conf: {mean_conf:.3f}\n"
# f" - Std Conf: {std_conf:.3f}"
# )
# mask_img.save("predicted_mask.png")
# return "✅ Segmentation Complete!", mask_img, prob_img, overlay_img, "predicted_mask.png", stats_text
# css = """
# body {
# background: #f9fafb;
# color: #2c3e50;
# font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
# }
# h1, h2, h3 {
# color: #34495e;
# margin-bottom: 0.2em;
# }
# .gradio-container {
# max-width: 1100px;
# margin: 1.5rem auto;
# padding: 1rem 2rem;
# }
# .gr-button {
# background-color: #0078d7;
# color: white;
# font-weight: 600;
# border-radius: 8px;
# padding: 12px 25px;
# }
# .gr-button:hover {
# background-color: #005a9e;
# }
# .gr-slider label, .gr-checkbox label {
# font-weight: 600;
# color: #34495e;
# }
# .gr-image input[type="file"] {
# border-radius: 8px;
# }
# .gr-file label {
# font-weight: 600;
# }
# .gr-textbox textarea {
# font-family: monospace;
# font-size: 0.9rem;
# background: #ecf0f1;
# border-radius: 6px;
# padding: 8px;
# }
# """
# with gr.Blocks(css=css) as demo:
# gr.Markdown("<h1 style='text-align:center; margin-bottom:0.25em;'>🧬 Fibril Segmentation Interface</h1>")
# gr.Markdown("<p style='text-align:center; font-size:1.1rem; color:#555; margin-top:0; margin-bottom:2em;'>Upload a grayscale microscopy image and segment fibrillar structures with advanced deep learning models.</p>")
# with gr.Row():
# with gr.Column(scale=1):
# input_img = gr.Image(label="Upload Grayscale Image", type="pil", interactive=True, elem_id="input-img", sources=["upload"])
# gr.Examples(
# examples=[[f"examples/example{i}.jpg"] for i in range(1, 8)],
# inputs=input_img,
# label="📁 Try Example Images",
# cache_examples=False,
# elem_id="examples"
# )
# with gr.Column(scale=1):
# model_selector = gr.Dropdown(
# choices=list(MODEL_OPTIONS.keys()),
# value="UNet++ (ResNet34)",
# label="Select Model",
# interactive=True
# )
# model_info = gr.Textbox(
# label="Model Description",
# interactive=False,
# lines=3,
# max_lines=5,
# elem_id="model-desc",
# show_label=True,
# container=True
# )
# def update_model_info(name):
# return MODEL_OPTIONS[name]["description"]
# model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info)
# gr.Markdown("### Segmentation Options")
# threshold_slider = gr.Slider(
# minimum=0, maximum=1, value=0.5, step=0.01,
# label="Segmentation Threshold",
# interactive=True,
# info="Adjust threshold for binarizing segmentation probability."
# )
# use_otsu = gr.Checkbox(
# label="Use Otsu Threshold",
# value=False,
# info="Automatically select optimal threshold using Otsu's method."
# )
# remove_noise = gr.Checkbox(
# label="Remove Small Objects",
# value=False,
# info="Remove small noise blobs from mask."
# )
# fill_holes = gr.Checkbox(
# label="Fill Holes in Mask",
# value=True,
# info="Fill small holes inside segmented objects."
# )
# show_overlay = gr.Checkbox(
# label="Show Overlay on Original",
# value=True,
# info="Display the mask overlaid on the original image."
# )
# show_stats = gr.Checkbox(
# label="Show Area & Confidence Stats",
# value=True,
# info="Display segmentation statistics like area and confidence."
# )
# submit = gr.Button("🟢 Segment Image", variant="primary", elem_id="submit-btn")
# gr.Markdown("---")
# status = gr.Textbox(label="Status", interactive=False, lines=1, elem_id="status-msg")
# with gr.Row():
# mask_output = gr.Image(label="Binary Mask", interactive=False, type="pil")
# prob_output = gr.Image(label="Confidence Map", interactive=False, type="pil")
# overlay_output = gr.Image(label="Overlay", interactive=False, type="pil")
# stats_output = gr.Textbox(label="Segmentation Stats", interactive=False, lines=6, elem_id="stats")
# file_output = gr.File(label="Download Segmentation Mask")
# submit.click(
# fn=predict,
# inputs=[
# input_img, model_selector, threshold_slider, use_otsu,
# remove_noise, fill_holes, show_overlay, show_stats
# ],
# outputs=[
# status, mask_output, prob_output, overlay_output,
# file_output, stats_output
# ]
# )
# demo.load(fn=update_model_info, inputs=model_selector, outputs=model_info)
# if __name__ == "__main__":
# demo.launch()
# +++++++++++++++ Final Version: 1.7.0 ++++++++++++++++++++++
# Last Updated: 10 July 2025
# Improvements: Added model selection, better UI, and device handling, oligomer labeling and fibril labeling
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
# import os
# import torch
# import numpy as np
# from PIL import Image, ImageDraw, ImageFont
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
# import segmentation_models_pytorch as smp
# import gradio as gr
# from skimage import filters, measure, morphology
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# MODEL_OPTIONS = {
# "UNet++ (ResNet34)": {
# "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
# "encoder": "resnet34",
# "architecture": "UnetPlusPlus",
# "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy."
# },
# "UNet (ResNet34)": {
# "path": "./model/unet_fibril_seg_model.pth",
# "encoder": "resnet34",
# "architecture": "Unet",
# "description": "Classic UNet with ResNet34 encoder — fast and lightweight."
# },
# "UNet++ (efficientnet-b3)": {
# "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
# "encoder": "efficientnet-b3",
# "architecture": "UnetPlusPlus",
# "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU."
# },
# "DeepLabV3Plus (efficientnet-b3)": {
# "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth",
# "encoder": "efficientnet-b3",
# "architecture": "DeepLabV3Plus",
# "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU."
# }
# }
# def get_transform(size):
# return A.Compose([
# A.Resize(size, size),
# A.Normalize(mean=(0.5,), std=(0.5,)),
# ToTensorV2()
# ])
# transform = get_transform(512)
# model_cache = {}
# def load_model(model_name):
# if model_name in model_cache:
# return model_cache[model_name]
# config = MODEL_OPTIONS[model_name]
# if config["architecture"] == "UnetPlusPlus":
# model = smp.UnetPlusPlus(
# encoder_name=config["encoder"], encoder_weights="imagenet",
# decoder_channels=(256, 128, 64, 32, 16),
# in_channels=1, classes=1, activation=None)
# elif config["architecture"] == "Unet":
# model = smp.Unet(
# encoder_name=config["encoder"], encoder_weights="imagenet",
# decoder_channels=(256, 128, 64, 32, 16),
# in_channels=1, classes=1, activation=None)
# elif config["architecture"] == "DeepLabV3Plus":
# model = smp.DeepLabV3Plus(
# encoder_name=config["encoder"], encoder_weights="imagenet",
# in_channels=1, classes=1, activation=None)
# else:
# raise ValueError("Unsupported architecture.")
# model.load_state_dict(torch.load(config["path"], map_location=device))
# model.eval()
# model_cache[model_name] = model.to(device)
# return model_cache[model_name]
# def draw_labels_on_image(orig_img, binary_mask, max_oligomer_size):
# """
# Draws numbers on the overlay image labeling oligomers and fibrils.
# Oligomers labeled as O1, O2, ...
# Fibrils labeled as F1, F2, ...
# """
# overlay = orig_img.convert("RGB").copy()
# draw = ImageDraw.Draw(overlay)
# # Try to get a nice font; fallback to default if not available
# try:
# font = ImageFont.truetype("arial.ttf", 18)
# except IOError:
# font = ImageFont.load_default()
# labeled_mask = measure.label(binary_mask)
# regions = measure.regionprops(labeled_mask)
# oligomer_count = 0
# fibril_count = 0
# for region in regions:
# area = region.area
# centroid = region.centroid # (row, col)
# x, y = int(centroid[1]), int(centroid[0])
# if area <= max_oligomer_size:
# oligomer_count += 1
# label_text = f"O{oligomer_count}"
# label_color = (0, 255, 0) # Green for oligomers
# else:
# fibril_count += 1
# label_text = f"F{fibril_count}"
# label_color = (255, 0, 0) # Red for fibrils
# # Draw circle around centroid for visibility
# r = 12
# draw.ellipse((x-r, y-r, x+r, y+r), outline=label_color, width=2)
# # Draw label text
# bbox = draw.textbbox((0, 0), label_text, font=font)
# text_width = bbox[2] - bbox[0]
# text_height = bbox[3] - bbox[1]
# text_pos = (x - text_width // 2, y - text_height // 2)
# draw.text(text_pos, label_text, fill=label_color, font=font)
# return overlay, oligomer_count, fibril_count
# @torch.no_grad()
# def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, show_overlay, show_stats, max_oligomer_size, keep_only_oligomers):
# if image is None:
# return "❌ Please upload an image.", None, None, None, "", "", "", ""
# image = image.convert("L")
# img_np = np.array(image)
# img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
# model = load_model(model_name)
# pred = torch.sigmoid(model(img_tensor)).cpu().squeeze().numpy()
# if use_otsu:
# threshold = filters.threshold_otsu(pred)
# binary_mask = (pred > threshold).astype(np.float32)
# if remove_noise:
# binary_mask = morphology.remove_small_objects(binary_mask > 0, 64)
# if fill_holes:
# binary_mask = morphology.remove_small_holes(binary_mask > 0, 64)
# binary_mask = binary_mask.astype(np.float32)
# # If user wants to keep only oligomers, remove large fibrils from mask
# if keep_only_oligomers:
# labeled_mask = measure.label(binary_mask)
# regions = measure.regionprops(labeled_mask)
# filtered_mask = np.zeros_like(binary_mask)
# for region in regions:
# if region.area <= max_oligomer_size:
# filtered_mask[labeled_mask == region.label] = 1
# binary_mask = filtered_mask.astype(np.float32)
# mask_img = Image.fromarray((binary_mask * 255).astype(np.uint8))
# prob_img = Image.fromarray((pred * 255).astype(np.uint8))
# overlay_img = None
# oligomer_count = 0
# fibril_count = 0
# if show_overlay:
# # Resize mask to original image size
# mask_resized = mask_img.resize(image.size, resample=Image.NEAREST).convert("RGB")
# # Blend original and mask
# base_overlay = Image.blend(image.convert("RGB"), mask_resized, alpha=0.4)
# # Draw labels on overlay
# overlay_img, oligomer_count, fibril_count = draw_labels_on_image(base_overlay, np.array(mask_img) > 0, max_oligomer_size)
# else:
# overlay_img = None
# stats_text = ""
# if show_stats:
# labeled_mask = measure.label(binary_mask)
# area = np.sum(binary_mask)
# mean_conf = np.mean(pred[binary_mask > 0]) if area > 0 else 0
# std_conf = np.std(pred[binary_mask > 0]) if area > 0 else 0
# total_objects = labeled_mask.max()
# # Count oligomers and fibrils
# oligomers = 0
# fibrils = 0
# for region in measure.regionprops(labeled_mask):
# if region.area <= max_oligomer_size:
# oligomers += 1
# else:
# fibrils += 1
# stats_text = (
# f"🧮 Stats:\n"
# f" - Area (px): {area:.0f}\n"
# f" - Total Objects: {total_objects}\n"
# f" - Oligomers (small): {oligomers}\n"
# f" - Fibrils (large): {fibrils}\n"
# f" - Mean Confidence: {mean_conf:.3f}\n"
# f" - Std Confidence: {std_conf:.3f}"
# )
# mask_img.save("predicted_mask.png")
# return "✅ Segmentation Complete!", mask_img, prob_img, overlay_img, "predicted_mask.png", stats_text, str(oligomer_count), str(fibril_count)
# css = """
# body {
# background: #f9fafb;
# color: #2c3e50;
# font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
# }
# h1, h2, h3 {
# color: #34495e;
# margin-bottom: 0.2em;
# }
# .gradio-container {
# max-width: 1100px;
# margin: 1.5rem auto;
# padding: 1rem 2rem;
# }
# .gr-button {
# background-color: #0078d7;
# color: white;
# font-weight: 600;
# border-radius: 8px;
# padding: 12px 25px;
# }
# .gr-button:hover {
# background-color: #005a9e;
# }
# .gr-slider label, .gr-checkbox label {
# font-weight: 600;
# color: #34495e;
# }
# .gr-image input[type="file"] {
# border-radius: 8px;
# }
# .gr-file label {
# font-weight: 600;
# }
# .gr-textbox textarea {
# font-family: monospace;
# font-size: 0.9rem;
# background: #ecf0f1;
# border-radius: 6px;
# padding: 8px;
# }
# """
# with gr.Blocks(css=css) as demo:
# gr.Markdown("<h1 style='text-align:center; margin-bottom:0.25em;'>🧬 Fibril Segmentation Interface</h1>")
# gr.Markdown("<p style='text-align:center; font-size:1.1rem; color:#555; margin-top:0; margin-bottom:2em;'>Upload a grayscale microscopy image and segment fibrillar structures with advanced deep learning models.</p>")
# with gr.Row():
# with gr.Column(scale=1):
# input_img = gr.Image(label="Upload Grayscale Image", type="pil", interactive=True, elem_id="input-img", sources=["upload"])
# gr.Examples(
# examples=[[f"examples/example{i}.jpg"] for i in range(1, 8)],
# inputs=input_img,
# label="📁 Try Example Images",
# cache_examples=False,
# elem_id="examples"
# )
# with gr.Column(scale=1):
# model_selector = gr.Dropdown(
# choices=list(MODEL_OPTIONS.keys()),
# value="UNet++ (ResNet34)",
# label="Select Model",
# interactive=True
# )
# model_info = gr.Textbox(
# label="Model Description",
# interactive=False,
# lines=3,
# max_lines=5,
# elem_id="model-desc",
# show_label=True,
# container=True
# )
# def update_model_info(name):
# return MODEL_OPTIONS[name]["description"]
# model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info)
# gr.Markdown("### Segmentation Options")
# threshold_slider = gr.Slider(
# minimum=0, maximum=1, value=0.5, step=0.01,
# label="Segmentation Threshold",
# interactive=True,
# info="Adjust threshold for binarizing segmentation probability."
# )
# use_otsu = gr.Checkbox(
# label="Use Otsu Threshold",
# value=False,
# info="Automatically select optimal threshold using Otsu's method."
# )
# remove_noise = gr.Checkbox(
# label="Remove Small Objects",
# value=False,
# info="Remove small noise blobs from mask."
# )
# fill_holes = gr.Checkbox(
# label="Fill Holes in Mask",
# value=True,
# info="Fill small holes inside segmented objects."
# )
# show_overlay = gr.Checkbox(
# label="Show Overlay on Original",
# value=True,
# info="Display the mask overlaid on the original image."
# )
# show_stats = gr.Checkbox(
# label="Show Area & Confidence Stats",
# value=True,
# info="Display segmentation statistics like area and confidence."
# )
# max_oligomer_size = gr.Slider(
# minimum=10, maximum=1000, value=400, step=10,
# label="Max Oligomer Size (px)",
# interactive=True,
# info="Objects smaller or equal to this area (in pixels) are oligomers; larger are fibrils."
# )
# keep_only_oligomers = gr.Checkbox(
# label="Keep Only Oligomers (Remove Large Fibrils)",
# value=False,
# info="If enabled, only oligomers remain in the final mask."
# )
# submit = gr.Button("🟢 Segment Image", variant="primary", elem_id="submit-btn")
# gr.Markdown("---")
# status = gr.Textbox(label="Status", interactive=False, lines=1, elem_id="status-msg")
# with gr.Row():
# mask_output = gr.Image(label="Binary Mask", interactive=False, type="pil")
# prob_output = gr.Image(label="Confidence Map", interactive=False, type="pil")
# overlay_output = gr.Image(label="Overlay with Labels", interactive=False, type="pil")
# stats_output = gr.Textbox(label="Segmentation Stats", interactive=False, lines=8, elem_id="stats")
# # Optional: show counts separately with big text
# oligomer_count_txt = gr.Textbox(label="Oligomer Count", interactive=False, lines=1)
# fibril_count_txt = gr.Textbox(label="Fibril Count", interactive=False, lines=1)
# file_output = gr.File(label="Download Segmentation Mask")
# submit.click(
# fn=predict,
# inputs=[
# input_img, model_selector, threshold_slider, use_otsu,
# remove_noise, fill_holes, show_overlay, show_stats,
# max_oligomer_size, keep_only_oligomers
# ],
# outputs=[
# status, mask_output, prob_output, overlay_output,
# file_output, stats_output, oligomer_count_txt, fibril_count_txt
# ]
# )
# demo.load(fn=update_model_info, inputs=model_selector, outputs=model_info)
# if __name__ == "__main__":
# demo.launch()
# +++++++++++++++ Final Version: 1.8.0 ++++++++++++++++++++++
# Last Updated: 10 July 2025
# Improvements: Added model selection, better UI, and device handling, oligomer labeling and fibril labeling
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import gradio as gr
from skimage import filters, measure, morphology
# Select GPU when available; all models and tensors are moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Registry of selectable models. Keys are the display names shown in the UI
# dropdown; each entry holds the checkpoint path, smp encoder name,
# architecture tag dispatched on in load_model(), and a short description
# surfaced in the "Model Description" textbox.
MODEL_OPTIONS = {
    "UNet++ (ResNet34)": {
        "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
        "encoder": "resnet34",
        "architecture": "UnetPlusPlus",
        "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy."
    },
    "UNet (ResNet34)": {
        "path": "./model/unet_fibril_seg_model.pth",
        "encoder": "resnet34",
        "architecture": "Unet",
        "description": "Classic UNet with ResNet34 encoder — fast and lightweight."
    },
    "UNet++ (efficientnet-b3)": {
        "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
        "encoder": "efficientnet-b3",
        "architecture": "UnetPlusPlus",
        "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU."
    },
    "DeepLabV3Plus (efficientnet-b3)": {
        "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth",
        "encoder": "efficientnet-b3",
        "architecture": "DeepLabV3Plus",
        "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU."
    }
}
def get_transform(size):
    """Build the inference preprocessing pipeline.

    Resizes the input to ``size``x``size``, normalizes the single grayscale
    channel with mean/std 0.5 (mapping [0, 1] to roughly [-1, 1]), and
    converts the result to a PyTorch tensor.
    """
    steps = [
        A.Resize(size, size),
        A.Normalize(mean=(0.5,), std=(0.5,)),
        ToTensorV2(),
    ]
    return A.Compose(steps)
# Shared preprocessing pipeline; every model here was trained on 512x512 input.
transform = get_transform(512)
# Lazily-populated cache: model display name -> loaded model on `device`.
model_cache = {}
def load_model(model_name):
    """Return the segmentation model registered under *model_name*.

    Models are constructed once, loaded from their checkpoint, switched to
    eval mode, moved to `device`, and memoized in `model_cache` so repeated
    predictions reuse the same instance.

    Raises ValueError for an architecture tag not handled below.
    """
    cached = model_cache.get(model_name)
    if cached is not None:
        return cached
    config = MODEL_OPTIONS[model_name]
    arch = config["architecture"]
    # Keyword arguments common to every architecture: 1-channel grayscale in,
    # 1-class logit map out (sigmoid is applied by the caller).
    common = dict(
        encoder_name=config["encoder"], encoder_weights="imagenet",
        in_channels=1, classes=1, activation=None,
    )
    if arch == "UnetPlusPlus":
        model = smp.UnetPlusPlus(decoder_channels=(256, 128, 64, 32, 16), **common)
    elif arch == "Unet":
        model = smp.Unet(decoder_channels=(256, 128, 64, 32, 16), **common)
    elif arch == "DeepLabV3Plus":
        model = smp.DeepLabV3Plus(**common)
    else:
        raise ValueError("Unsupported architecture.")
    model.load_state_dict(torch.load(config["path"], map_location=device))
    model.eval()
    model_cache[model_name] = model.to(device)
    return model_cache[model_name]
def draw_labels_on_image(orig_img, binary_mask, max_oligomer_size, fibril_length_thresh=100):
    """Annotate each connected component of *binary_mask* on a copy of *orig_img*.

    Classification: a region whose circularity (4*pi*area / perimeter^2) is
    at least 1 AND whose area is <= max_oligomer_size is labeled an oligomer
    ("O1", "O2", ... in green); every other region is labeled a fibril
    ("F1", "F2", ... in red). Each region gets a circle at its centroid plus
    a centered text label.

    NOTE(review): fibril_length_thresh is accepted for API compatibility but
    does not affect the outcome — long and "neither circular nor long"
    regions are both labeled as fibrils.

    Returns (annotated RGB image, oligomer_count, fibril_count).
    """
    overlay = orig_img.convert("RGB").copy()
    draw = ImageDraw.Draw(overlay)
    # Prefer a TrueType font; fall back to PIL's built-in bitmap font.
    try:
        font = ImageFont.truetype("arial.ttf", 18)
    except IOError:
        font = ImageFont.load_default()
    oligomer_count = 0
    fibril_count = 0
    circularity_thresh = 1  # a perfect circle has circularity 1
    for region in measure.regionprops(measure.label(binary_mask)):
        area = region.area
        perimeter = region.perimeter if region.perimeter > 0 else 1  # avoid div-by-zero
        circularity = 4 * np.pi * area / (perimeter ** 2)
        row, col = region.centroid
        x, y = int(col), int(row)
        if circularity >= circularity_thresh and area <= max_oligomer_size:
            oligomer_count += 1
            label_text = f"O{oligomer_count}"
            label_color = (0, 255, 0)  # green for oligomers
        else:
            # Long objects and ambiguous (neither circular nor long) objects
            # are both counted as fibrils.
            fibril_count += 1
            label_text = f"F{fibril_count}"
            label_color = (255, 0, 0)  # red for fibrils
        radius = 12
        draw.ellipse((x - radius, y - radius, x + radius, y + radius),
                     outline=label_color, width=2)
        left, top, right, bottom = draw.textbbox((0, 0), label_text, font=font)
        text_pos = (x - (right - left) // 2, y - (bottom - top) // 2)
        draw.text(text_pos, label_text, fill=label_color, font=font)
    return overlay, oligomer_count, fibril_count
@torch.no_grad()
def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, show_overlay, show_stats, max_oligomer_size, keep_only_oligomers):
    """Run segmentation on a PIL image and return all UI outputs.

    Returns an 8-tuple matching the Gradio output components:
    (status message, binary mask image, probability map image, labeled
    overlay image or None, mask file path, stats text, oligomer count
    string, fibril count string).
    """
    if image is None:
        # Mirror the 8-slot output signature so Gradio components all clear.
        return "❌ Please upload an image.", None, None, None, "", "", "", ""
    image = image.convert("L")
    img_np = np.array(image)
    # Preprocess (resize to 512x512, normalize) and add a batch dimension.
    img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
    model = load_model(model_name)
    # Per-pixel foreground probabilities in [0, 1], shape 512x512.
    pred = torch.sigmoid(model(img_tensor)).cpu().squeeze().numpy()
    if use_otsu:
        # Otsu overrides the slider threshold when enabled.
        threshold = filters.threshold_otsu(pred)
    binary_mask = (pred > threshold).astype(np.float32)
    # Optional morphological clean-up (both operate on/return boolean masks).
    if remove_noise:
        binary_mask = morphology.remove_small_objects(binary_mask > 0, 64)
    if fill_holes:
        binary_mask = morphology.remove_small_holes(binary_mask > 0, 64)
    binary_mask = binary_mask.astype(np.float32)
    # If user wants to keep only oligomers, remove large fibrils from mask
    if keep_only_oligomers:
        labeled_mask = measure.label(binary_mask)
        regions = measure.regionprops(labeled_mask)
        filtered_mask = np.zeros_like(binary_mask)
        for region in regions:
            if region.area <= max_oligomer_size:
                filtered_mask[labeled_mask == region.label] = 1
        binary_mask = filtered_mask.astype(np.float32)
    mask_img = Image.fromarray((binary_mask * 255).astype(np.uint8))
    prob_img = Image.fromarray((pred * 255).astype(np.uint8))
    overlay_img = None
    oligomer_count = 0
    fibril_count = 0
    if show_overlay:
        # Resize mask to original image size
        mask_resized = mask_img.resize(image.size, resample=Image.NEAREST).convert("RGB")
        # Blend original and mask
        base_overlay = Image.blend(image.convert("RGB"), mask_resized, alpha=0.4)
        # Draw labels on overlay
        # NOTE(review): the mask handed to draw_labels_on_image is 512x512,
        # while base_overlay is at the original image size — label positions
        # appear misaligned unless the upload is 512x512. TODO confirm/fix.
        # overlay_img, oligomer_count, fibril_count = draw_labels_on_image(base_overlay, np.array(mask_img) > 0, max_oligomer_size)
        overlay_img, oligomer_count, fibril_count = draw_labels_on_image(base_overlay, np.array(mask_img) > 0, max_oligomer_size, fibril_length_thresh=100)
    else:
        overlay_img = None
    stats_text = ""
    if show_stats:
        labeled_mask = measure.label(binary_mask)
        area = np.sum(binary_mask)
        # Confidence stats are computed but intentionally left out of the
        # report (see the commented-out lines below).
        mean_conf = np.mean(pred[binary_mask > 0]) if area > 0 else 0
        std_conf = np.std(pred[binary_mask > 0]) if area > 0 else 0
        total_objects = labeled_mask.max()
        # Count oligomers and fibrils
        # (area-only split here, unlike the circularity rule used for overlay
        # labels, so the two counts can differ)
        oligomers = 0
        fibrils = 0
        for region in measure.regionprops(labeled_mask):
            if region.area <= max_oligomer_size:
                oligomers += 1
            else:
                fibrils += 1
        stats_text = (
            f"🧮 Stats:\n"
            f" - Area (px): {area:.0f}\n"
            f" - Total Objects: {total_objects}\n"
            f" - Oligomers (small): {oligomers}\n"
            f" - Fibrils (large): {fibrils}\n"
            # f" - Mean Confidence: {mean_conf:.3f}\n"
            # f" - Std Confidence: {std_conf:.3f}"
        )
    # Persist the mask so the gr.File component can offer it for download.
    mask_img.save("predicted_mask.png")
    return "✅ Segmentation Complete!", mask_img, prob_img, overlay_img, "predicted_mask.png", stats_text, str(oligomer_count), str(fibril_count)
css = """
body {
background: #f9fafb;
color: #2c3e50;
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
h1, h2, h3 {
color: #34495e;
margin-bottom: 0.2em;
}
.gradio-container {
max-width: 1100px;
margin: 1.5rem auto;
padding: 1rem 2rem;
}
.gr-button {
background-color: #0078d7;
color: white;
font-weight: 600;
border-radius: 8px;
padding: 12px 25px;
}
.gr-button:hover {
background-color: #005a9e;
}
.gr-slider label, .gr-checkbox label {
font-weight: 600;
color: #34495e;
}
.gr-image input[type="file"] {
border-radius: 8px;
}
.gr-file label {
font-weight: 600;
}
.gr-textbox textarea {
font-family: monospace;
font-size: 0.9rem;
background: #ecf0f1;
border-radius: 6px;
padding: 8px;
}
"""
# ─── UI Layout ─────────────────────────────────────────────
# Component creation order defines the visual layout; inputs/outputs are
# wired to predict() at the bottom via submit.click().
with gr.Blocks(css=css) as demo:
    gr.Markdown("<h1 style='text-align:center; margin-bottom:0.25em;'>🧬 Fibril Segmentation Interface</h1>")
    gr.Markdown("<p style='text-align:center; font-size:1.1rem; color:#555; margin-top:0; margin-bottom:2em;'>Upload a grayscale microscopy image and segment fibrillar structures with advanced deep learning models.</p>")
    with gr.Row():
        # Left column: image upload plus clickable example images.
        with gr.Column(scale=1):
            input_img = gr.Image(label="Upload Grayscale Image", type="pil", interactive=True, elem_id="input-img", sources=["upload"])
            gr.Examples(
                examples=[[f"examples/example{i}.jpg"] for i in range(1, 8)],
                inputs=input_img,
                label="📁 Try Example Images",
                cache_examples=False,
                elem_id="examples"
            )
        # Right column: model choice, description, and all tuning controls.
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(MODEL_OPTIONS.keys()),
                value="UNet++ (ResNet34)",
                label="Select Model",
                interactive=True
            )
            model_info = gr.Textbox(
                label="Model Description",
                interactive=False,
                lines=3,
                max_lines=5,
                elem_id="model-desc",
                show_label=True,
                container=True
            )
            # Keep the description textbox in sync with the dropdown.
            def update_model_info(name):
                return MODEL_OPTIONS[name]["description"]
            model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info)
            gr.Markdown("### Segmentation Options")
            threshold_slider = gr.Slider(
                minimum=0, maximum=1, value=0.5, step=0.01,
                label="Segmentation Threshold",
                interactive=True,
                info="Adjust threshold for binarizing segmentation probability."
            )
            use_otsu = gr.Checkbox(
                label="Use Otsu Threshold",
                value=False,
                info="Automatically select optimal threshold using Otsu's method."
            )
            remove_noise = gr.Checkbox(
                label="Remove Small Objects",
                value=False,
                info="Remove small noise blobs from mask."
            )
            fill_holes = gr.Checkbox(
                label="Fill Holes in Mask",
                value=True,
                info="Fill small holes inside segmented objects."
            )
            show_overlay = gr.Checkbox(
                label="Show Overlay on Original",
                value=True,
                info="Display the mask overlaid on the original image."
            )
            show_stats = gr.Checkbox(
                label="Show Area & Confidence Stats",
                value=True,
                info="Display segmentation statistics like area and confidence."
            )
            max_oligomer_size = gr.Slider(
                minimum=10, maximum=1000, value=400, step=10,
                label="Max Oligomer Size (px)",
                interactive=True,
                info="Objects smaller or equal to this area (in pixels) are oligomers; larger are fibrils."
            )
            keep_only_oligomers = gr.Checkbox(
                label="Keep Only Oligomers (Remove Large Fibrils)",
                value=False,
                info="If enabled, only oligomers remain in the final mask."
            )
            submit = gr.Button("🟢 Segment Image", variant="primary", elem_id="submit-btn")
    gr.Markdown("---")
    # Output components — their order must match predict()'s return tuple.
    status = gr.Textbox(label="Status", interactive=False, lines=1, elem_id="status-msg")
    with gr.Row():
        mask_output = gr.Image(label="Binary Mask", interactive=False, type="pil")
        prob_output = gr.Image(label="Confidence Map", interactive=False, type="pil")
        overlay_output = gr.Image(label="Overlay with Labels", interactive=False, type="pil")
    stats_output = gr.Textbox(label="Segmentation Stats", interactive=False, lines=8, elem_id="stats")
    # Optional: show counts separately with big text
    oligomer_count_txt = gr.Textbox(label="Oligomer Count", interactive=False, lines=1)
    fibril_count_txt = gr.Textbox(label="Fibril Count", interactive=False, lines=1)
    file_output = gr.File(label="Download Segmentation Mask")
    submit.click(
        fn=predict,
        inputs=[
            input_img, model_selector, threshold_slider, use_otsu,
            remove_noise, fill_holes, show_overlay, show_stats,
            max_oligomer_size, keep_only_oligomers
        ],
        outputs=[
            status, mask_output, prob_output, overlay_output,
            file_output, stats_output, oligomer_count_txt, fibril_count_txt
        ]
    )
    # Populate the model description once on page load.
    demo.load(fn=update_model_info, inputs=model_selector, outputs=model_info)
if __name__ == "__main__":
    demo.launch()