Spaces:
Sleeping
Remove mean and standard-deviation confidence outputs from prediction results; update predicted mask image
671374a
# NOTE(review): archived version 1.0.0, deliberately kept commented out for
# history. De-garbled: removed table-pipe scrape artifacts ("| ... | |") and
# restored the original indentation; the code content itself is unchanged.
# # +++++++++++++++ First Version: 1.0.0 +++++++++++++++++++++
# # Last Updated: 08 July 2025
# # Fibril Segmentation with UNet++ using Gradio
# import os
# import torch
# import numpy as np
# from PIL import Image
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
# import segmentation_models_pytorch as smp
# import gradio as gr
# # ─── Configuration ─────────────────────────────────────────
# CONFIG = {
#     "model_path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
#     "img_size": 512
# }
# # ─── Device Setup ──────────────────────────────────────────
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"✅ Using device: {device}")
# # ─── Load Model ────────────────────────────────────────────
# model = smp.UnetPlusPlus(
#     encoder_name='resnet34',
#     encoder_depth=5,
#     encoder_weights='imagenet',
#     decoder_use_norm='batchnorm',
#     decoder_channels=(256, 128, 64, 32, 16),
#     decoder_attention_type=None,
#     decoder_interpolation='nearest',
#     in_channels=1,
#     classes=1,
#     activation=None
# ).to(device)
# model.load_state_dict(torch.load(CONFIG["model_path"], map_location=device))
# model.eval()
# # ─── Transform Function ────────────────────────────────────
# def get_transform(size):
#     return A.Compose([
#         A.Resize(size, size),
#         A.Normalize(mean=(0.5,), std=(0.5,)),
#         ToTensorV2()
#     ])
# transform = get_transform(CONFIG["img_size"])
# # ─── Prediction Function ───────────────────────────────────
# def predict(image):
#     image = image.convert("L")  # Convert to grayscale
#     img_np = np.array(image)
#     img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
#     with torch.no_grad():
#         pred = torch.sigmoid(model(img_tensor))
#     mask = (pred > 0.5).float().cpu().squeeze().numpy()
#     mask_img = Image.fromarray((mask * 255).astype(np.uint8))
#     return mask_img
# # ─── Gradio Interface ──────────────────────────────────────
# demo = gr.Interface(
#     fn=predict,
#     inputs=gr.Image(type="pil", label="Upload Microscopy Image"),
#     outputs=gr.Image(type="pil", label="Predicted Segmentation Mask"),
#     title="Fibril Segmentation with Unet++",
#     description="Upload a grayscale microscopy image to get its predicted segmentation mask.",
#     allow_flagging="never",
#     live=False
# )
# if __name__ == "__main__":
#     demo.launch()
# NOTE(review): archived version 1.1.0, deliberately kept commented out for
# history. De-garbled: removed table-pipe scrape artifacts ("| ... | |") and
# restored the original indentation; the code content itself is unchanged.
# # +++++++++++++++ Second Version: 1.1.0 ++++++++++++++++++++
# # Last Updated: 08 July 2025
# # Improvements: Added examples, better UI, and device handling
# import os
# import torch
# import numpy as np
# from PIL import Image
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
# import segmentation_models_pytorch as smp
# import gradio as gr
# # ─── Configuration ─────────────────────────────────────────
# CONFIG = {
#     "model_path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
#     "img_size": 512
# }
# # ─── Device Setup ──────────────────────────────────────────
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"✅ Using device: {device}")
# # ─── Load Model ────────────────────────────────────────────
# model = smp.UnetPlusPlus(
#     encoder_name='resnet34',
#     encoder_depth=5,
#     encoder_weights='imagenet',
#     decoder_channels=(256, 128, 64, 32, 16),
#     in_channels=1,
#     classes=1,
#     activation=None
# ).to(device)
# model.load_state_dict(torch.load(CONFIG["model_path"], map_location=device))
# model.eval()
# # ─── Transform Function ────────────────────────────────────
# def get_transform(size):
#     return A.Compose([
#         A.Resize(size, size),
#         A.Normalize(mean=(0.5,), std=(0.5,)),
#         ToTensorV2()
#     ])
# transform = get_transform(CONFIG["img_size"])
# # ─── Prediction Function ───────────────────────────────────
# def predict(image):
#     image = image.convert("L")  # Ensure grayscale
#     img_np = np.array(image)
#     img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
#     with torch.no_grad():
#         pred = torch.sigmoid(model(img_tensor))
#     mask = (pred > 0.5).float().cpu().squeeze().numpy()
#     mask_img = Image.fromarray((mask * 255).astype(np.uint8))
#     return mask_img
# # ─── Gradio UI (Improved) ──────────────────────────────────
# examples = [
#     ["examples/example1.jpg"],
#     ["examples/example2.jpg"],
#     ["examples/example3.jpg"],
#     ["examples/example4.jpg"],
#     ["examples/example5.jpg"],
#     ["examples/example6.jpg"],
#     ["examples/example7.jpg"]
# ]
# css = """
# .gradio-container {
#     max-width: 950px;
#     margin: auto;
# }
# .gr-button {
#     background-color: #4a90e2;
#     color: white;
#     border-radius: 5px;
# }
# .gr-button:hover {
#     background-color: #357ABD;
# }
# """
# with gr.Blocks(css=css) as demo:
#     gr.Markdown("## 🧬 Fibril Segmentation with UNet++")
#     gr.Markdown("Upload a **grayscale microscopy image**, and this model will predict the **segmentation mask of fibrillar structures**.\n\nModel: ResNet34 encoder + UNet++ decoder")
#     with gr.Row():
#         input_img = gr.Image(label="Upload Microscopy Image", type="pil")
#         output_mask = gr.Image(label="Predicted Segmentation Mask", type="pil")
#     submit_btn = gr.Button("Segment Image")
#     submit_btn.click(fn=predict, inputs=input_img, outputs=output_mask)
#     gr.Examples(
#         examples=examples,
#         inputs=input_img,
#         label="Try with Example Images",
#         cache_examples=False
#     )
# # ─── Launch App ────────────────────────────────────────────
# if __name__ == "__main__":
#     demo.launch()
# NOTE(review): archived version 1.2.0, deliberately kept commented out for
# history (including its own superseded load_model draft, marked "# #").
# De-garbled: removed table-pipe scrape artifacts ("| ... | |") and restored
# the original indentation; the code content itself is unchanged.
# +++++++++++++++ Final Version: 1.2.0 ++++++++++++++++++++++
# Last Updated: 08 July 2025
# Improvements: Added model selection, better UI, and device handling
# import os
# import torch
# import numpy as np
# from PIL import Image
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
# import segmentation_models_pytorch as smp
# import gradio as gr
# # ─── Device Setup ──────────────────────────────────────────
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"✅ Using device: {device}")
# # ─── Model Configurations ──────────────────────────────────
# MODEL_OPTIONS = {
#     "UNet++ (ResNet34)": {
#         "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
#         "encoder": "resnet34",
#         "architecture": "UnetPlusPlus"
#     },
#     "UNet (ResNet34)": {
#         "path": "./model/unet_fibril_seg_model.pth",
#         "encoder": "resnet34",
#         "architecture": "Unet"
#     },
#     "UNet++ (efficientnet-b3)": {
#         "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
#         "encoder": "efficientnet-b3",
#         "architecture": "UnetPlusPlus"
#     },
#     "DeepLabV3Plus (efficientnet-b3)": {
#         "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth",
#         "encoder": "efficientnet-b3",
#         "architecture": "DeepLabV3Plus"
#     }
# }
# # ─── Transform Function ────────────────────────────────────
# def get_transform(size):
#     return A.Compose([
#         A.Resize(size, size),
#         A.Normalize(mean=(0.5,), std=(0.5,)),
#         ToTensorV2()
#     ])
# transform = get_transform(512)
# # ─── Model Loader ──────────────────────────────────────────
# # def load_model(model_name):
# #     config = MODEL_OPTIONS[model_name]
# #     if config["architecture"] == "UnetPlusPlus":
# #         model = smp.UnetPlusPlus(
# #             encoder_name=config["encoder"],
# #             encoder_weights="imagenet",
# #             decoder_channels=(256, 128, 64, 32, 16),
# #             in_channels=1,
# #             classes=1,
# #             activation=None
# #         )
# #     elif config["architecture"] == "Unet":
# #         model = smp.Unet(
# #             encoder_name=config["encoder"],
# #             encoder_weights="imagenet",
# #             decoder_channels=(256, 128, 64, 32, 16),
# #             in_channels=1,
# #             classes=1,
# #             activation=None
# #         )
# #     else:
# #         raise ValueError(f"Unsupported architecture: {config['architecture']}")
# #     model.load_state_dict(torch.load(config["path"], map_location=device))
# #     model.eval()
# #     return model.to(device)
# model_cache = {}
# def load_model(model_name):
#     if model_name in model_cache:
#         return model_cache[model_name]
#     config = MODEL_OPTIONS[model_name]
#     if config["architecture"] == "UnetPlusPlus":
#         model = smp.UnetPlusPlus(
#             encoder_name=config["encoder"],
#             encoder_weights="imagenet",
#             decoder_channels=(256, 128, 64, 32, 16),
#             in_channels=1,
#             classes=1,
#             activation=None
#         )
#     elif config["architecture"] == "Unet":
#         model = smp.Unet(
#             encoder_name=config["encoder"],
#             encoder_weights="imagenet",
#             decoder_channels=(256, 128, 64, 32, 16),
#             in_channels=1,
#             classes=1,
#             activation=None
#         )
#     elif config["architecture"] == "DeepLabV3Plus":
#         model = smp.DeepLabV3Plus(
#             encoder_name=config["encoder"],
#             encoder_weights="imagenet",
#             # decoder_channels=(256, 128, 64, 32, 16),  # Not used in DeepLabV3Plus
#             in_channels=1,
#             classes=1,
#             activation=None
#         )
#     else:
#         raise ValueError(f"Unsupported architecture: {config['architecture']}")
#     model.load_state_dict(torch.load(config["path"], map_location=device))
#     model.eval()
#     model_cache[model_name] = model.to(device)
#     return model_cache[model_name]
# # ─── Prediction Function ───────────────────────────────────
# def predict(image, model_name):
#     image = image.convert("L")
#     img_np = np.array(image)
#     img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
#     model = load_model(model_name)
#     with torch.no_grad():
#         pred = torch.sigmoid(model(img_tensor))
#     mask = (pred > 0.5).float().cpu().squeeze().numpy()
#     mask_img = Image.fromarray((mask * 255).astype(np.uint8))
#     return mask_img
# # ─── Example Images ────────────────────────────────────────
# examples = [
#     ["examples/example1.jpg"],
#     ["examples/example2.jpg"],
#     ["examples/example3.jpg"],
#     ["examples/example4.jpg"],
#     ["examples/example5.jpg"],
#     ["examples/example6.jpg"],
#     ["examples/example7.jpg"]
# ]
# # ─── Custom CSS ────────────────────────────────────────────
# css = """
# .gradio-container {
#     max-width: 950px;
#     margin: auto;
# }
# .gr-button {
#     background-color: #4a90e2;
#     color: white;
#     border-radius: 5px;
# }
# .gr-button:hover {
#     background-color: #357ABD;
# }
# """
# # ─── Gradio UI ─────────────────────────────────────────────
# with gr.Blocks(css=css) as demo:
#     gr.Markdown("## 🧬 Fibril Segmentation Interface")
#     gr.Markdown("Choose a model and upload a grayscale microscopy image. The model will predict the **fibrillar structure mask**.")
#     with gr.Row():
#         model_selector = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="UNet++ (ResNet34)", label="Select Model")
#     with gr.Row():
#         input_img = gr.Image(label="Upload Microscopy Image", type="pil")
#         output_mask = gr.Image(label="Predicted Segmentation Mask", type="pil")
#     submit_btn = gr.Button("Segment Image")
#     submit_btn.click(fn=predict, inputs=[input_img, model_selector], outputs=output_mask)
#     gr.Examples(
#         examples=examples,
#         inputs=input_img,
#         label="Try with Example Images",
#         cache_examples=False
#     )
# # ─── Launch App ────────────────────────────────────────────
# if __name__ == "__main__":
#     demo.launch()
# NOTE(review): archived version 1.3.0, deliberately kept commented out for
# history. De-garbled: removed table-pipe scrape artifacts ("| ... | |") and
# restored the original indentation; the code content itself is unchanged.
# +++++++++++++++ Final Version: 1.3.0 ++++++++++++++++++++++
# Last Updated: 10 July 2025
# Improvements: Added model selection, better UI, and device handling
# import os
# import torch
# import numpy as np
# from PIL import Image
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
# import segmentation_models_pytorch as smp
# import gradio as gr
# # ─── Device Setup ──────────────────────────────────────────
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"✅ Using device: {device}")
# # ─── Model Configurations ──────────────────────────────────
# MODEL_OPTIONS = {
#     "UNet++ (ResNet34)": {
#         "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
#         "encoder": "resnet34",
#         "architecture": "UnetPlusPlus",
#         "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy."
#     },
#     "UNet (ResNet34)": {
#         "path": "./model/unet_fibril_seg_model.pth",
#         "encoder": "resnet34",
#         "architecture": "Unet",
#         "description": "Classic UNet with ResNet34 encoder — fast and lightweight."
#     },
#     "UNet++ (efficientnet-b3)": {
#         "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
#         "encoder": "efficientnet-b3",
#         "architecture": "UnetPlusPlus",
#         "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU."
#     },
#     "DeepLabV3Plus (efficientnet-b3)": {
#         "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth",
#         "encoder": "efficientnet-b3",
#         "architecture": "DeepLabV3Plus",
#         "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU."
#     }
# }
# # ─── Transform Function ────────────────────────────────────
# def get_transform(size):
#     return A.Compose([
#         A.Resize(size, size),
#         A.Normalize(mean=(0.5,), std=(0.5,)),
#         ToTensorV2()
#     ])
# transform = get_transform(512)
# # ─── Model Loader ──────────────────────────────────────────
# model_cache = {}
# def load_model(model_name):
#     if model_name in model_cache:
#         return model_cache[model_name]
#     config = MODEL_OPTIONS[model_name]
#     if config["architecture"] == "UnetPlusPlus":
#         model = smp.UnetPlusPlus(
#             encoder_name=config["encoder"],
#             encoder_weights="imagenet",
#             decoder_channels=(256, 128, 64, 32, 16),
#             in_channels=1,
#             classes=1,
#             activation=None
#         )
#     elif config["architecture"] == "Unet":
#         model = smp.Unet(
#             encoder_name=config["encoder"],
#             encoder_weights="imagenet",
#             decoder_channels=(256, 128, 64, 32, 16),
#             in_channels=1,
#             classes=1,
#             activation=None
#         )
#     elif config["architecture"] == "DeepLabV3Plus":
#         model = smp.DeepLabV3Plus(
#             encoder_name=config["encoder"],
#             encoder_weights="imagenet",
#             in_channels=1,
#             classes=1,
#             activation=None
#         )
#     else:
#         raise ValueError(f"Unsupported architecture: {config['architecture']}")
#     model.load_state_dict(torch.load(config["path"], map_location=device))
#     model.eval()
#     model_cache[model_name] = model.to(device)
#     return model_cache[model_name]
# # ─── Prediction Function ───────────────────────────────────
# @torch.no_grad()
# def predict(image, model_name, threshold):
#     if image is None:
#         return "❌ Please upload an image.", None, None, None
#     image = image.convert("L")
#     img_np = np.array(image)
#     img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
#     model = load_model(model_name)
#     pred = torch.sigmoid(model(img_tensor))
#     prob_map = pred.cpu().squeeze().numpy()
#     mask = (prob_map > threshold).astype(np.float32)
#     mask_img = Image.fromarray((mask * 255).astype(np.uint8))
#     prob_img = Image.fromarray((prob_map * 255).astype(np.uint8))
#     # Save mask temporarily for download
#     out_path = "predicted_mask.png"
#     mask_img.save(out_path)
#     return "✅ Done!", mask_img, prob_img, out_path
# # ─── UI and Layout ─────────────────────────────────────────
# css = """
# .gradio-container {
#     max-width: 950px;
#     margin: auto;
# }
# .gr-button {
#     background-color: #4a90e2;
#     color: white;
#     border-radius: 5px;
# }
# .gr-button:hover {
#     background-color: #357ABD;
# }
# """
# with gr.Blocks(css=css) as demo:
#     gr.Markdown("## 🧬 Fibril Segmentation Interface")
#     gr.Markdown("Upload a grayscale microscopy image, choose a model, adjust threshold, and get a fibrillar structure mask.")
#     with gr.Row():
#         model_selector = gr.Dropdown(
#             choices=list(MODEL_OPTIONS.keys()),
#             value="UNet++ (ResNet34)",
#             label="Select Model"
#         )
#         threshold_slider = gr.Slider(
#             minimum=0.0, maximum=1.0, value=0.5, step=0.01,
#             label="Segmentation Threshold"
#         )
#     model_info = gr.Textbox(
#         label="Model Description",
#         lines=3,
#         interactive=False,
#         value=MODEL_OPTIONS[model_selector.value]["description"]
#     )
#     def update_model_info(model_name):
#         return MODEL_OPTIONS[model_name]["description"]
#     model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info)
#     with gr.Tabs():
#         with gr.Tab("🖼️ Segmentation"):
#             with gr.Row():
#                 input_img = gr.Image(
#                     label="Upload Microscopy Image",
#                     type="pil",
#                 )
#                 output_mask = gr.Image(label="Predicted Mask", type="pil")
#                 confidence_map = gr.Image(label="Confidence Map", type="pil")
#             submit_btn = gr.Button("Segment Image")
#             status_box = gr.Textbox(label="Status", interactive=False)
#             output_download = gr.File(label="Download Predicted Mask")
#             submit_btn.click(
#                 fn=predict,
#                 inputs=[input_img, model_selector, threshold_slider],
#                 outputs=[status_box, output_mask, confidence_map, output_download]
#             )
#         with gr.Tab("📁 Examples"):
#             gr.Examples(
#                 examples=[
#                     ["examples/example1.jpg"],
#                     ["examples/example2.jpg"],
#                     ["examples/example3.jpg"],
#                     ["examples/example4.jpg"],
#                     ["examples/example5.jpg"],
#                     ["examples/example6.jpg"],
#                     ["examples/example7.jpg"]
#                 ],
#                 inputs=input_img,
#                 label="Try with Example Images",
#                 cache_examples=False
#             )
# if __name__ == "__main__":
#     demo.launch()
# NOTE(review): archived version 1.4.0, deliberately kept commented out for
# history. De-garbled: removed table-pipe scrape artifacts ("| ... | |") and
# restored the original indentation; the code content itself is unchanged.
# +++++++++++++++ Final Version: 1.4.0 ++++++++++++++++++++++
# Last Updated: 10 July 2025
# Improvements: Added model selection, better UI, and device handling
# import os
# import torch
# import numpy as np
# from PIL import Image
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
# import segmentation_models_pytorch as smp
# import gradio as gr
# # ─── Device Setup ──────────────────────────────────────────
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"✅ Using device: {device}")
# # ─── Model Configurations ──────────────────────────────────
# MODEL_OPTIONS = {
#     "UNet++ (ResNet34)": {
#         "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
#         "encoder": "resnet34",
#         "architecture": "UnetPlusPlus",
#         "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy."
#     },
#     "UNet (ResNet34)": {
#         "path": "./model/unet_fibril_seg_model.pth",
#         "encoder": "resnet34",
#         "architecture": "Unet",
#         "description": "Classic UNet with ResNet34 encoder — fast and lightweight."
#     },
#     "UNet++ (efficientnet-b3)": {
#         "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
#         "encoder": "efficientnet-b3",
#         "architecture": "UnetPlusPlus",
#         "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU."
#     },
#     "DeepLabV3Plus (efficientnet-b3)": {
#         "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth",
#         "encoder": "efficientnet-b3",
#         "architecture": "DeepLabV3Plus",
#         "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU."
#     }
# }
# # ─── Transform Function ────────────────────────────────────
# def get_transform(size):
#     return A.Compose([
#         A.Resize(size, size),
#         A.Normalize(mean=(0.5,), std=(0.5,)),
#         ToTensorV2()
#     ])
# transform = get_transform(512)
# # ─── Model Loader ──────────────────────────────────────────
# model_cache = {}
# def load_model(model_name):
#     if model_name in model_cache:
#         return model_cache[model_name]
#     config = MODEL_OPTIONS[model_name]
#     if config["architecture"] == "UnetPlusPlus":
#         model = smp.UnetPlusPlus(
#             encoder_name=config["encoder"],
#             encoder_weights="imagenet",
#             decoder_channels=(256, 128, 64, 32, 16),
#             in_channels=1,
#             classes=1,
#             activation=None
#         )
#     elif config["architecture"] == "Unet":
#         model = smp.Unet(
#             encoder_name=config["encoder"],
#             encoder_weights="imagenet",
#             decoder_channels=(256, 128, 64, 32, 16),
#             in_channels=1,
#             classes=1,
#             activation=None
#         )
#     elif config["architecture"] == "DeepLabV3Plus":
#         model = smp.DeepLabV3Plus(
#             encoder_name=config["encoder"],
#             encoder_weights="imagenet",
#             in_channels=1,
#             classes=1,
#             activation=None
#         )
#     else:
#         raise ValueError(f"Unsupported architecture: {config['architecture']}")
#     model.load_state_dict(torch.load(config["path"], map_location=device))
#     model.eval()
#     model_cache[model_name] = model.to(device)
#     return model_cache[model_name]
# # ─── Prediction Function ───────────────────────────────────
# @torch.no_grad()
# def predict(image, model_name, threshold):
#     if image is None:
#         return "❌ Please upload an image.", None, None, None
#     image = image.convert("L")
#     img_np = np.array(image)
#     img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
#     model = load_model(model_name)
#     pred = torch.sigmoid(model(img_tensor))
#     prob_map = pred.cpu().squeeze().numpy()
#     mask = (prob_map > threshold).astype(np.float32)
#     mask_img = Image.fromarray((mask * 255).astype(np.uint8))
#     prob_img = Image.fromarray((prob_map * 255).astype(np.uint8))
#     out_path = "predicted_mask.png"
#     mask_img.save(out_path)
#     return "✅ Done!", mask_img, prob_img, out_path
# # ─── UI and Layout ─────────────────────────────────────────
# css = """
# .gradio-container {
#     max-width: 1100px;
#     margin: auto;
# }
# .gr-button {
#     background-color: #4a90e2;
#     color: white;
#     border-radius: 5px;
# }
# .gr-button:hover {
#     background-color: #357ABD;
# }
# """
# with gr.Blocks(css=css) as demo:
#     gr.Markdown("## 🧬 Fibril Segmentation Interface")
#     gr.Markdown("Upload a grayscale microscopy image, choose a model, adjust threshold, and get a fibrillar structure mask.")
#     with gr.Row():
#         # Left Panel – Examples and Upload
#         with gr.Column(scale=1):
#             gr.Markdown("### 📁 Example Images")
#             input_img = gr.Image(
#                 label="Upload Microscopy Image",
#                 type="pil",
#                 interactive=True
#             )
#             gr.Examples(
#                 examples=[
#                     ["examples/example1.jpg"],
#                     ["examples/example2.jpg"],
#                     ["examples/example3.jpg"],
#                     ["examples/example4.jpg"],
#                     ["examples/example5.jpg"],
#                     ["examples/example6.jpg"],
#                     ["examples/example7.jpg"]
#                 ],
#                 inputs=input_img,
#                 label="Try with Example Images",
#                 cache_examples=False
#             )
#         # Right Panel – Controls & Output
#         with gr.Column(scale=2):
#             with gr.Row():
#                 model_selector = gr.Dropdown(
#                     choices=list(MODEL_OPTIONS.keys()),
#                     value="UNet++ (ResNet34)",
#                     label="Select Model"
#                 )
#                 threshold_slider = gr.Slider(
#                     minimum=0.0, maximum=1.0, value=0.5, step=0.01,
#                     label="Segmentation Threshold"
#                 )
#             model_info = gr.Textbox(
#                 label="Model Description",
#                 lines=3,
#                 interactive=False,
#                 value=MODEL_OPTIONS["UNet++ (ResNet34)"]["description"]
#             )
#             def update_model_info(model_name):
#                 return MODEL_OPTIONS[model_name]["description"]
#             model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info)
#             with gr.Row():
#                 output_mask = gr.Image(label="Predicted Mask", type="pil")
#                 confidence_map = gr.Image(label="Confidence Map", type="pil")
#             submit_btn = gr.Button("Segment Image")
#             status_box = gr.Textbox(label="Status", interactive=False)
#             output_download = gr.File(label="Download Predicted Mask")
#             submit_btn.click(
#                 fn=predict,
#                 inputs=[input_img, model_selector, threshold_slider],
#                 outputs=[status_box, output_mask, confidence_map, output_download]
#             )
# # ─── Launch App ────────────────────────────────────────────
# if __name__ == "__main__":
#     demo.launch()
# NOTE(review): archived version 1.5.0, commented out; this section is
# truncated here (the gr.Blocks UI continues past the visible text — the
# output-row wiring, Examples tab closure, and launch guard are not shown).
# De-garbled: removed table-pipe scrape artifacts ("| ... | |") and restored
# the original indentation; the visible code content itself is unchanged.
# +++++++++++++++ Final Version: 1.5.0 ++++++++++++++++++++++
# Last Updated: 10 July 2025
# Improvements: Added model selection, better UI, and device handling
# import os
# import torch
# import numpy as np
# from PIL import Image
# import albumentations as A
# from albumentations.pytorch import ToTensorV2
# import segmentation_models_pytorch as smp
# import gradio as gr
# from skimage import filters, measure, morphology
# # ─── Device Setup ─────────────────────────────
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"✅ Using device: {device}")
# # ─── Model Configurations ─────────────────────
# MODEL_OPTIONS = {
#     "UNet++ (ResNet34)": {
#         "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
#         "encoder": "resnet34",
#         "architecture": "UnetPlusPlus",
#         "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy."
#     },
#     "UNet (ResNet34)": {
#         "path": "./model/unet_fibril_seg_model.pth",
#         "encoder": "resnet34",
#         "architecture": "Unet",
#         "description": "Classic UNet with ResNet34 encoder — fast and lightweight."
#     },
#     "UNet++ (efficientnet-b3)": {
#         "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
#         "encoder": "efficientnet-b3",
#         "architecture": "UnetPlusPlus",
#         "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU."
#     },
#     "DeepLabV3Plus (efficientnet-b3)": {
#         "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth",
#         "encoder": "efficientnet-b3",
#         "architecture": "DeepLabV3Plus",
#         "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU."
#     }
# }
# def get_transform(size):
#     return A.Compose([
#         A.Resize(size, size),
#         A.Normalize(mean=(0.5,), std=(0.5,)),
#         ToTensorV2()
#     ])
# transform = get_transform(512)
# model_cache = {}
# def load_model(model_name):
#     if model_name in model_cache:
#         return model_cache[model_name]
#     config = MODEL_OPTIONS[model_name]
#     if config["architecture"] == "UnetPlusPlus":
#         model = smp.UnetPlusPlus(
#             encoder_name=config["encoder"], encoder_weights="imagenet",
#             decoder_channels=(256, 128, 64, 32, 16),
#             in_channels=1, classes=1, activation=None)
#     elif config["architecture"] == "Unet":
#         model = smp.Unet(
#             encoder_name=config["encoder"], encoder_weights="imagenet",
#             decoder_channels=(256, 128, 64, 32, 16),
#             in_channels=1, classes=1, activation=None)
#     elif config["architecture"] == "DeepLabV3Plus":
#         model = smp.DeepLabV3Plus(
#             encoder_name=config["encoder"], encoder_weights="imagenet",
#             in_channels=1, classes=1, activation=None)
#     else:
#         raise ValueError("Unsupported architecture.")
#     model.load_state_dict(torch.load(config["path"], map_location=device))
#     model.eval()
#     model_cache[model_name] = model.to(device)
#     return model_cache[model_name]
# # ─── Prediction Logic ─────────────────────────
# @torch.no_grad()
# def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, show_overlay, show_stats):
#     if image is None:
#         return "❌ Please upload an image.", None, None, None, "", ""
#     image = image.convert("L")
#     img_np = np.array(image)
#     img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
#     model = load_model(model_name)
#     pred = torch.sigmoid(model(img_tensor)).cpu().squeeze().numpy()
#     if use_otsu:
#         threshold = filters.threshold_otsu(pred)
#     binary_mask = (pred > threshold).astype(np.float32)
#     # Post-process mask
#     if remove_noise:
#         binary_mask = morphology.remove_small_objects(binary_mask > 0, 64)
#     if fill_holes:
#         binary_mask = morphology.remove_small_holes(binary_mask > 0, 64)
#     binary_mask = binary_mask.astype(np.float32)
#     mask_img = Image.fromarray((binary_mask * 255).astype(np.uint8))
#     prob_img = Image.fromarray((pred * 255).astype(np.uint8))
#     if show_overlay:
#         mask_resized = mask_img.resize(image.size, resample=Image.NEAREST).convert("RGB")
#         overlay_img = Image.blend(image.convert("RGB"), mask_resized, alpha=0.4)
#     else:
#         overlay_img = None
#     stats_text = ""
#     if show_stats:
#         labeled_mask = measure.label(binary_mask)
#         area = np.sum(binary_mask)
#         mean_conf = np.mean(pred[binary_mask > 0]) if area > 0 else 0
#         std_conf = np.std(pred[binary_mask > 0]) if area > 0 else 0
#         stats_text = (
#             f"🧮 Stats:\n - Area (px): {area:.0f}\n"
#             f" - Objects: {labeled_mask.max()}\n"
#             f" - Mean Conf: {mean_conf:.3f}\n"
#             f" - Std Conf: {std_conf:.3f}"
#         )
#     mask_img.save("predicted_mask.png")
#     return "✅ Done!", mask_img, prob_img, overlay_img, "predicted_mask.png", stats_text
# # ─── UI ───────────────────────────────────────
# css = """
# .gradio-container { max-width: 1100px; margin: auto; }
# .gr-button { background-color: #4a90e2; color: white; border-radius: 5px; }
# .gr-button:hover { background-color: #357ABD; }
# """
# with gr.Blocks(css=css) as demo:
#     gr.Markdown("## 🧬 Fibril Segmentation Interface")
#     gr.Markdown("Upload a grayscale microscopy image and choose options to segment fibrillar structures.")
#     with gr.Row():
#         with gr.Column(scale=1):
#             input_img = gr.Image(label="Upload Image", type="pil", interactive=True)
#             gr.Examples(
#                 examples=[[f"examples/example{i}.jpg"] for i in range(1, 8)],
#                 inputs=input_img,
#                 label="📁 Try Example Images", cache_examples=False
#             )
#         with gr.Column(scale=2):
#             model_selector = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="UNet++ (ResNet34)", label="Model")
#             threshold_slider = gr.Slider(0, 1, value=0.5, step=0.01, label="Segmentation Threshold")
#             use_otsu = gr.Checkbox(label="Use Otsu Threshold", value=False)
#             remove_noise = gr.Checkbox(label="Remove Small Objects", value=True)
#             fill_holes = gr.Checkbox(label="Fill Holes in Mask", value=True)
#             show_overlay = gr.Checkbox(label="Show Overlay on Original", value=True)
#             show_stats = gr.Checkbox(label="Show Area & Confidence Stats", value=True)
#             model_info = gr.Textbox(label="Model Description", interactive=False, lines=2)
#             def update_model_info(name): return MODEL_OPTIONS[name]["description"]
#             model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info)
#             submit = gr.Button("Segment Image")
#             status = gr.Textbox(label="Status", interactive=False)
#     with gr.Row():
#         mask_output = gr.Image(label="Binary Mask")
#         prob_output = gr.Image(label="Confidence Map")
#         overlay_output = gr.Image(label="Overlay")
| # stats_output = gr.Textbox(label="Segmentation Stats", lines=4) | |
| # file_output = gr.File(label="Download Mask") | |
| # submit.click( | |
| # fn=predict, | |
| # inputs=[ | |
| # input_img, model_selector, threshold_slider, use_otsu, | |
| # remove_noise, fill_holes, show_overlay, show_stats | |
| # ], | |
| # outputs=[ | |
| # status, mask_output, prob_output, overlay_output, | |
| # file_output, stats_output | |
| # ] | |
| # ) | |
| # if __name__ == "__main__": | |
| # demo.launch() | |
| # +++++++++++++++ Final Version: 1.6.0 ++++++++++++++++++++++ | |
| # Last Updated: 10 July 2025 | |
| # Improvements: Added model selection, better UI, and device handling | |
| # import os | |
| # import torch | |
| # import numpy as np | |
| # from PIL import Image | |
| # import albumentations as A | |
| # from albumentations.pytorch import ToTensorV2 | |
| # import segmentation_models_pytorch as smp | |
| # import gradio as gr | |
| # from skimage import filters, measure, morphology | |
| # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # MODEL_OPTIONS = { | |
| # "UNet++ (ResNet34)": { | |
| # "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth", | |
| # "encoder": "resnet34", | |
| # "architecture": "UnetPlusPlus", | |
| # "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy." | |
| # }, | |
| # "UNet (ResNet34)": { | |
| # "path": "./model/unet_fibril_seg_model.pth", | |
| # "encoder": "resnet34", | |
| # "architecture": "Unet", | |
| # "description": "Classic UNet with ResNet34 encoder — fast and lightweight." | |
| # }, | |
| # "UNet++ (efficientnet-b3)": { | |
| # "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth", | |
| # "encoder": "efficientnet-b3", | |
| # "architecture": "UnetPlusPlus", | |
| # "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU." | |
| # }, | |
| # "DeepLabV3Plus (efficientnet-b3)": { | |
| # "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth", | |
| # "encoder": "efficientnet-b3", | |
| # "architecture": "DeepLabV3Plus", | |
| # "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU." | |
| # } | |
| # } | |
| # def get_transform(size): | |
| # return A.Compose([ | |
| # A.Resize(size, size), | |
| # A.Normalize(mean=(0.5,), std=(0.5,)), | |
| # ToTensorV2() | |
| # ]) | |
| # transform = get_transform(512) | |
| # model_cache = {} | |
| # def load_model(model_name): | |
| # if model_name in model_cache: | |
| # return model_cache[model_name] | |
| # config = MODEL_OPTIONS[model_name] | |
| # if config["architecture"] == "UnetPlusPlus": | |
| # model = smp.UnetPlusPlus( | |
| # encoder_name=config["encoder"], encoder_weights="imagenet", | |
| # decoder_channels=(256, 128, 64, 32, 16), | |
| # in_channels=1, classes=1, activation=None) | |
| # elif config["architecture"] == "Unet": | |
| # model = smp.Unet( | |
| # encoder_name=config["encoder"], encoder_weights="imagenet", | |
| # decoder_channels=(256, 128, 64, 32, 16), | |
| # in_channels=1, classes=1, activation=None) | |
| # elif config["architecture"] == "DeepLabV3Plus": | |
| # model = smp.DeepLabV3Plus( | |
| # encoder_name=config["encoder"], encoder_weights="imagenet", | |
| # in_channels=1, classes=1, activation=None) | |
| # else: | |
| # raise ValueError("Unsupported architecture.") | |
| # model.load_state_dict(torch.load(config["path"], map_location=device)) | |
| # model.eval() | |
| # model_cache[model_name] = model.to(device) | |
| # return model_cache[model_name] | |
| # @torch.no_grad() | |
| # def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, show_overlay, show_stats): | |
| # if image is None: | |
| # return "❌ Please upload an image.", None, None, None, "", "" | |
| # image = image.convert("L") | |
| # img_np = np.array(image) | |
| # img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device) | |
| # model = load_model(model_name) | |
| # pred = torch.sigmoid(model(img_tensor)).cpu().squeeze().numpy() | |
| # if use_otsu: | |
| # threshold = filters.threshold_otsu(pred) | |
| # binary_mask = (pred > threshold).astype(np.float32) | |
| # if remove_noise: | |
| # binary_mask = morphology.remove_small_objects(binary_mask > 0, 64) | |
| # if fill_holes: | |
| # binary_mask = morphology.remove_small_holes(binary_mask > 0, 64) | |
| # binary_mask = binary_mask.astype(np.float32) | |
| # mask_img = Image.fromarray((binary_mask * 255).astype(np.uint8)) | |
| # prob_img = Image.fromarray((pred * 255).astype(np.uint8)) | |
| # if show_overlay: | |
| # mask_resized = mask_img.resize(image.size, resample=Image.NEAREST).convert("RGB") | |
| # overlay_img = Image.blend(image.convert("RGB"), mask_resized, alpha=0.4) | |
| # else: | |
| # overlay_img = None | |
| # stats_text = "" | |
| # if show_stats: | |
| # labeled_mask = measure.label(binary_mask) | |
| # area = np.sum(binary_mask) | |
| # mean_conf = np.mean(pred[binary_mask > 0]) if area > 0 else 0 | |
| # std_conf = np.std(pred[binary_mask > 0]) if area > 0 else 0 | |
| # stats_text = ( | |
| # f"🧮 Stats:\n - Area (px): {area:.0f}\n" | |
| # f" - Objects: {labeled_mask.max()}\n" | |
| # f" - Mean Conf: {mean_conf:.3f}\n" | |
| # f" - Std Conf: {std_conf:.3f}" | |
| # ) | |
| # mask_img.save("predicted_mask.png") | |
| # return "✅ Segmentation Complete!", mask_img, prob_img, overlay_img, "predicted_mask.png", stats_text | |
| # css = """ | |
| # body { | |
| # background: #f9fafb; | |
| # color: #2c3e50; | |
| # font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; | |
| # } | |
| # h1, h2, h3 { | |
| # color: #34495e; | |
| # margin-bottom: 0.2em; | |
| # } | |
| # .gradio-container { | |
| # max-width: 1100px; | |
| # margin: 1.5rem auto; | |
| # padding: 1rem 2rem; | |
| # } | |
| # .gr-button { | |
| # background-color: #0078d7; | |
| # color: white; | |
| # font-weight: 600; | |
| # border-radius: 8px; | |
| # padding: 12px 25px; | |
| # } | |
| # .gr-button:hover { | |
| # background-color: #005a9e; | |
| # } | |
| # .gr-slider label, .gr-checkbox label { | |
| # font-weight: 600; | |
| # color: #34495e; | |
| # } | |
| # .gr-image input[type="file"] { | |
| # border-radius: 8px; | |
| # } | |
| # .gr-file label { | |
| # font-weight: 600; | |
| # } | |
| # .gr-textbox textarea { | |
| # font-family: monospace; | |
| # font-size: 0.9rem; | |
| # background: #ecf0f1; | |
| # border-radius: 6px; | |
| # padding: 8px; | |
| # } | |
| # """ | |
| # with gr.Blocks(css=css) as demo: | |
| # gr.Markdown("<h1 style='text-align:center; margin-bottom:0.25em;'>🧬 Fibril Segmentation Interface</h1>") | |
| # gr.Markdown("<p style='text-align:center; font-size:1.1rem; color:#555; margin-top:0; margin-bottom:2em;'>Upload a grayscale microscopy image and segment fibrillar structures with advanced deep learning models.</p>") | |
| # with gr.Row(): | |
| # with gr.Column(scale=1): | |
| # input_img = gr.Image(label="Upload Grayscale Image", type="pil", interactive=True, elem_id="input-img", sources=["upload"]) | |
| # gr.Examples( | |
| # examples=[[f"examples/example{i}.jpg"] for i in range(1, 8)], | |
| # inputs=input_img, | |
| # label="📁 Try Example Images", | |
| # cache_examples=False, | |
| # elem_id="examples" | |
| # ) | |
| # with gr.Column(scale=1): | |
| # model_selector = gr.Dropdown( | |
| # choices=list(MODEL_OPTIONS.keys()), | |
| # value="UNet++ (ResNet34)", | |
| # label="Select Model", | |
| # interactive=True | |
| # ) | |
| # model_info = gr.Textbox( | |
| # label="Model Description", | |
| # interactive=False, | |
| # lines=3, | |
| # max_lines=5, | |
| # elem_id="model-desc", | |
| # show_label=True, | |
| # container=True | |
| # ) | |
| # def update_model_info(name): | |
| # return MODEL_OPTIONS[name]["description"] | |
| # model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info) | |
| # gr.Markdown("### Segmentation Options") | |
| # threshold_slider = gr.Slider( | |
| # minimum=0, maximum=1, value=0.5, step=0.01, | |
| # label="Segmentation Threshold", | |
| # interactive=True, | |
| # info="Adjust threshold for binarizing segmentation probability." | |
| # ) | |
| # use_otsu = gr.Checkbox( | |
| # label="Use Otsu Threshold", | |
| # value=False, | |
| # info="Automatically select optimal threshold using Otsu's method." | |
| # ) | |
| # remove_noise = gr.Checkbox( | |
| # label="Remove Small Objects", | |
| # value=False, | |
| # info="Remove small noise blobs from mask." | |
| # ) | |
| # fill_holes = gr.Checkbox( | |
| # label="Fill Holes in Mask", | |
| # value=True, | |
| # info="Fill small holes inside segmented objects." | |
| # ) | |
| # show_overlay = gr.Checkbox( | |
| # label="Show Overlay on Original", | |
| # value=True, | |
| # info="Display the mask overlaid on the original image." | |
| # ) | |
| # show_stats = gr.Checkbox( | |
| # label="Show Area & Confidence Stats", | |
| # value=True, | |
| # info="Display segmentation statistics like area and confidence." | |
| # ) | |
| # submit = gr.Button("🟢 Segment Image", variant="primary", elem_id="submit-btn") | |
| # gr.Markdown("---") | |
| # status = gr.Textbox(label="Status", interactive=False, lines=1, elem_id="status-msg") | |
| # with gr.Row(): | |
| # mask_output = gr.Image(label="Binary Mask", interactive=False, type="pil") | |
| # prob_output = gr.Image(label="Confidence Map", interactive=False, type="pil") | |
| # overlay_output = gr.Image(label="Overlay", interactive=False, type="pil") | |
| # stats_output = gr.Textbox(label="Segmentation Stats", interactive=False, lines=6, elem_id="stats") | |
| # file_output = gr.File(label="Download Segmentation Mask") | |
| # submit.click( | |
| # fn=predict, | |
| # inputs=[ | |
| # input_img, model_selector, threshold_slider, use_otsu, | |
| # remove_noise, fill_holes, show_overlay, show_stats | |
| # ], | |
| # outputs=[ | |
| # status, mask_output, prob_output, overlay_output, | |
| # file_output, stats_output | |
| # ] | |
| # ) | |
| # demo.load(fn=update_model_info, inputs=model_selector, outputs=model_info) | |
| # if __name__ == "__main__": | |
| # demo.launch() | |
| # +++++++++++++++ Final Version: 1.7.0 ++++++++++++++++++++++ | |
| # Last Updated: 10 July 2025 | |
| # Improvements: Added model selection, better UI, and device handling, oligomer labeling and fibril labeling | |
| # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
| # import os | |
| # import torch | |
| # import numpy as np | |
| # from PIL import Image, ImageDraw, ImageFont | |
| # import albumentations as A | |
| # from albumentations.pytorch import ToTensorV2 | |
| # import segmentation_models_pytorch as smp | |
| # import gradio as gr | |
| # from skimage import filters, measure, morphology | |
| # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
| # MODEL_OPTIONS = { | |
| # "UNet++ (ResNet34)": { | |
| # "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth", | |
| # "encoder": "resnet34", | |
| # "architecture": "UnetPlusPlus", | |
| # "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy." | |
| # }, | |
| # "UNet (ResNet34)": { | |
| # "path": "./model/unet_fibril_seg_model.pth", | |
| # "encoder": "resnet34", | |
| # "architecture": "Unet", | |
| # "description": "Classic UNet with ResNet34 encoder — fast and lightweight." | |
| # }, | |
| # "UNet++ (efficientnet-b3)": { | |
| # "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth", | |
| # "encoder": "efficientnet-b3", | |
| # "architecture": "UnetPlusPlus", | |
| # "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU." | |
| # }, | |
| # "DeepLabV3Plus (efficientnet-b3)": { | |
| # "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth", | |
| # "encoder": "efficientnet-b3", | |
| # "architecture": "DeepLabV3Plus", | |
| # "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU." | |
| # } | |
| # } | |
| # def get_transform(size): | |
| # return A.Compose([ | |
| # A.Resize(size, size), | |
| # A.Normalize(mean=(0.5,), std=(0.5,)), | |
| # ToTensorV2() | |
| # ]) | |
| # transform = get_transform(512) | |
| # model_cache = {} | |
| # def load_model(model_name): | |
| # if model_name in model_cache: | |
| # return model_cache[model_name] | |
| # config = MODEL_OPTIONS[model_name] | |
| # if config["architecture"] == "UnetPlusPlus": | |
| # model = smp.UnetPlusPlus( | |
| # encoder_name=config["encoder"], encoder_weights="imagenet", | |
| # decoder_channels=(256, 128, 64, 32, 16), | |
| # in_channels=1, classes=1, activation=None) | |
| # elif config["architecture"] == "Unet": | |
| # model = smp.Unet( | |
| # encoder_name=config["encoder"], encoder_weights="imagenet", | |
| # decoder_channels=(256, 128, 64, 32, 16), | |
| # in_channels=1, classes=1, activation=None) | |
| # elif config["architecture"] == "DeepLabV3Plus": | |
| # model = smp.DeepLabV3Plus( | |
| # encoder_name=config["encoder"], encoder_weights="imagenet", | |
| # in_channels=1, classes=1, activation=None) | |
| # else: | |
| # raise ValueError("Unsupported architecture.") | |
| # model.load_state_dict(torch.load(config["path"], map_location=device)) | |
| # model.eval() | |
| # model_cache[model_name] = model.to(device) | |
| # return model_cache[model_name] | |
| # def draw_labels_on_image(orig_img, binary_mask, max_oligomer_size): | |
| # """ | |
| # Draws numbers on the overlay image labeling oligomers and fibrils. | |
| # Oligomers labeled as O1, O2, ... | |
| # Fibrils labeled as F1, F2, ... | |
| # """ | |
| # overlay = orig_img.convert("RGB").copy() | |
| # draw = ImageDraw.Draw(overlay) | |
| # # Try to get a nice font; fallback to default if not available | |
| # try: | |
| # font = ImageFont.truetype("arial.ttf", 18) | |
| # except IOError: | |
| # font = ImageFont.load_default() | |
| # labeled_mask = measure.label(binary_mask) | |
| # regions = measure.regionprops(labeled_mask) | |
| # oligomer_count = 0 | |
| # fibril_count = 0 | |
| # for region in regions: | |
| # area = region.area | |
| # centroid = region.centroid # (row, col) | |
| # x, y = int(centroid[1]), int(centroid[0]) | |
| # if area <= max_oligomer_size: | |
| # oligomer_count += 1 | |
| # label_text = f"O{oligomer_count}" | |
| # label_color = (0, 255, 0) # Green for oligomers | |
| # else: | |
| # fibril_count += 1 | |
| # label_text = f"F{fibril_count}" | |
| # label_color = (255, 0, 0) # Red for fibrils | |
| # # Draw circle around centroid for visibility | |
| # r = 12 | |
| # draw.ellipse((x-r, y-r, x+r, y+r), outline=label_color, width=2) | |
| # # Draw label text | |
| # bbox = draw.textbbox((0, 0), label_text, font=font) | |
| # text_width = bbox[2] - bbox[0] | |
| # text_height = bbox[3] - bbox[1] | |
| # text_pos = (x - text_width // 2, y - text_height // 2) | |
| # draw.text(text_pos, label_text, fill=label_color, font=font) | |
| # return overlay, oligomer_count, fibril_count | |
| # @torch.no_grad() | |
| # def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, show_overlay, show_stats, max_oligomer_size, keep_only_oligomers): | |
| # if image is None: | |
| # return "❌ Please upload an image.", None, None, None, "", "", "", "" | |
| # image = image.convert("L") | |
| # img_np = np.array(image) | |
| # img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device) | |
| # model = load_model(model_name) | |
| # pred = torch.sigmoid(model(img_tensor)).cpu().squeeze().numpy() | |
| # if use_otsu: | |
| # threshold = filters.threshold_otsu(pred) | |
| # binary_mask = (pred > threshold).astype(np.float32) | |
| # if remove_noise: | |
| # binary_mask = morphology.remove_small_objects(binary_mask > 0, 64) | |
| # if fill_holes: | |
| # binary_mask = morphology.remove_small_holes(binary_mask > 0, 64) | |
| # binary_mask = binary_mask.astype(np.float32) | |
| # # If user wants to keep only oligomers, remove large fibrils from mask | |
| # if keep_only_oligomers: | |
| # labeled_mask = measure.label(binary_mask) | |
| # regions = measure.regionprops(labeled_mask) | |
| # filtered_mask = np.zeros_like(binary_mask) | |
| # for region in regions: | |
| # if region.area <= max_oligomer_size: | |
| # filtered_mask[labeled_mask == region.label] = 1 | |
| # binary_mask = filtered_mask.astype(np.float32) | |
| # mask_img = Image.fromarray((binary_mask * 255).astype(np.uint8)) | |
| # prob_img = Image.fromarray((pred * 255).astype(np.uint8)) | |
| # overlay_img = None | |
| # oligomer_count = 0 | |
| # fibril_count = 0 | |
| # if show_overlay: | |
| # # Resize mask to original image size | |
| # mask_resized = mask_img.resize(image.size, resample=Image.NEAREST).convert("RGB") | |
| # # Blend original and mask | |
| # base_overlay = Image.blend(image.convert("RGB"), mask_resized, alpha=0.4) | |
| # # Draw labels on overlay | |
| # overlay_img, oligomer_count, fibril_count = draw_labels_on_image(base_overlay, np.array(mask_img) > 0, max_oligomer_size) | |
| # else: | |
| # overlay_img = None | |
| # stats_text = "" | |
| # if show_stats: | |
| # labeled_mask = measure.label(binary_mask) | |
| # area = np.sum(binary_mask) | |
| # mean_conf = np.mean(pred[binary_mask > 0]) if area > 0 else 0 | |
| # std_conf = np.std(pred[binary_mask > 0]) if area > 0 else 0 | |
| # total_objects = labeled_mask.max() | |
| # # Count oligomers and fibrils | |
| # oligomers = 0 | |
| # fibrils = 0 | |
| # for region in measure.regionprops(labeled_mask): | |
| # if region.area <= max_oligomer_size: | |
| # oligomers += 1 | |
| # else: | |
| # fibrils += 1 | |
| # stats_text = ( | |
| # f"🧮 Stats:\n" | |
| # f" - Area (px): {area:.0f}\n" | |
| # f" - Total Objects: {total_objects}\n" | |
| # f" - Oligomers (small): {oligomers}\n" | |
| # f" - Fibrils (large): {fibrils}\n" | |
| # f" - Mean Confidence: {mean_conf:.3f}\n" | |
| # f" - Std Confidence: {std_conf:.3f}" | |
| # ) | |
| # mask_img.save("predicted_mask.png") | |
| # return "✅ Segmentation Complete!", mask_img, prob_img, overlay_img, "predicted_mask.png", stats_text, str(oligomer_count), str(fibril_count) | |
| # css = """ | |
| # body { | |
| # background: #f9fafb; | |
| # color: #2c3e50; | |
| # font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; | |
| # } | |
| # h1, h2, h3 { | |
| # color: #34495e; | |
| # margin-bottom: 0.2em; | |
| # } | |
| # .gradio-container { | |
| # max-width: 1100px; | |
| # margin: 1.5rem auto; | |
| # padding: 1rem 2rem; | |
| # } | |
| # .gr-button { | |
| # background-color: #0078d7; | |
| # color: white; | |
| # font-weight: 600; | |
| # border-radius: 8px; | |
| # padding: 12px 25px; | |
| # } | |
| # .gr-button:hover { | |
| # background-color: #005a9e; | |
| # } | |
| # .gr-slider label, .gr-checkbox label { | |
| # font-weight: 600; | |
| # color: #34495e; | |
| # } | |
| # .gr-image input[type="file"] { | |
| # border-radius: 8px; | |
| # } | |
| # .gr-file label { | |
| # font-weight: 600; | |
| # } | |
| # .gr-textbox textarea { | |
| # font-family: monospace; | |
| # font-size: 0.9rem; | |
| # background: #ecf0f1; | |
| # border-radius: 6px; | |
| # padding: 8px; | |
| # } | |
| # """ | |
| # with gr.Blocks(css=css) as demo: | |
| # gr.Markdown("<h1 style='text-align:center; margin-bottom:0.25em;'>🧬 Fibril Segmentation Interface</h1>") | |
| # gr.Markdown("<p style='text-align:center; font-size:1.1rem; color:#555; margin-top:0; margin-bottom:2em;'>Upload a grayscale microscopy image and segment fibrillar structures with advanced deep learning models.</p>") | |
| # with gr.Row(): | |
| # with gr.Column(scale=1): | |
| # input_img = gr.Image(label="Upload Grayscale Image", type="pil", interactive=True, elem_id="input-img", sources=["upload"]) | |
| # gr.Examples( | |
| # examples=[[f"examples/example{i}.jpg"] for i in range(1, 8)], | |
| # inputs=input_img, | |
| # label="📁 Try Example Images", | |
| # cache_examples=False, | |
| # elem_id="examples" | |
| # ) | |
| # with gr.Column(scale=1): | |
| # model_selector = gr.Dropdown( | |
| # choices=list(MODEL_OPTIONS.keys()), | |
| # value="UNet++ (ResNet34)", | |
| # label="Select Model", | |
| # interactive=True | |
| # ) | |
| # model_info = gr.Textbox( | |
| # label="Model Description", | |
| # interactive=False, | |
| # lines=3, | |
| # max_lines=5, | |
| # elem_id="model-desc", | |
| # show_label=True, | |
| # container=True | |
| # ) | |
| # def update_model_info(name): | |
| # return MODEL_OPTIONS[name]["description"] | |
| # model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info) | |
| # gr.Markdown("### Segmentation Options") | |
| # threshold_slider = gr.Slider( | |
| # minimum=0, maximum=1, value=0.5, step=0.01, | |
| # label="Segmentation Threshold", | |
| # interactive=True, | |
| # info="Adjust threshold for binarizing segmentation probability." | |
| # ) | |
| # use_otsu = gr.Checkbox( | |
| # label="Use Otsu Threshold", | |
| # value=False, | |
| # info="Automatically select optimal threshold using Otsu's method." | |
| # ) | |
| # remove_noise = gr.Checkbox( | |
| # label="Remove Small Objects", | |
| # value=False, | |
| # info="Remove small noise blobs from mask." | |
| # ) | |
| # fill_holes = gr.Checkbox( | |
| # label="Fill Holes in Mask", | |
| # value=True, | |
| # info="Fill small holes inside segmented objects." | |
| # ) | |
| # show_overlay = gr.Checkbox( | |
| # label="Show Overlay on Original", | |
| # value=True, | |
| # info="Display the mask overlaid on the original image." | |
| # ) | |
| # show_stats = gr.Checkbox( | |
| # label="Show Area & Confidence Stats", | |
| # value=True, | |
| # info="Display segmentation statistics like area and confidence." | |
| # ) | |
| # max_oligomer_size = gr.Slider( | |
| # minimum=10, maximum=1000, value=400, step=10, | |
| # label="Max Oligomer Size (px)", | |
| # interactive=True, | |
| # info="Objects smaller or equal to this area (in pixels) are oligomers; larger are fibrils." | |
| # ) | |
| # keep_only_oligomers = gr.Checkbox( | |
| # label="Keep Only Oligomers (Remove Large Fibrils)", | |
| # value=False, | |
| # info="If enabled, only oligomers remain in the final mask." | |
| # ) | |
| # submit = gr.Button("🟢 Segment Image", variant="primary", elem_id="submit-btn") | |
| # gr.Markdown("---") | |
| # status = gr.Textbox(label="Status", interactive=False, lines=1, elem_id="status-msg") | |
| # with gr.Row(): | |
| # mask_output = gr.Image(label="Binary Mask", interactive=False, type="pil") | |
| # prob_output = gr.Image(label="Confidence Map", interactive=False, type="pil") | |
| # overlay_output = gr.Image(label="Overlay with Labels", interactive=False, type="pil") | |
| # stats_output = gr.Textbox(label="Segmentation Stats", interactive=False, lines=8, elem_id="stats") | |
| # # Optional: show counts separately with big text | |
| # oligomer_count_txt = gr.Textbox(label="Oligomer Count", interactive=False, lines=1) | |
| # fibril_count_txt = gr.Textbox(label="Fibril Count", interactive=False, lines=1) | |
| # file_output = gr.File(label="Download Segmentation Mask") | |
| # submit.click( | |
| # fn=predict, | |
| # inputs=[ | |
| # input_img, model_selector, threshold_slider, use_otsu, | |
| # remove_noise, fill_holes, show_overlay, show_stats, | |
| # max_oligomer_size, keep_only_oligomers | |
| # ], | |
| # outputs=[ | |
| # status, mask_output, prob_output, overlay_output, | |
| # file_output, stats_output, oligomer_count_txt, fibril_count_txt | |
| # ] | |
| # ) | |
| # demo.load(fn=update_model_info, inputs=model_selector, outputs=model_info) | |
| # if __name__ == "__main__": | |
| # demo.launch() | |
| # +++++++++++++++ Final Version: 1.8.0 ++++++++++++++++++++++ | |
| # Last Updated: 10 July 2025 | |
| # Improvements: Added model selection, better UI, and device handling, oligomer labeling and fibril labeling | |
| # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ | |
| import os | |
| import torch | |
| import numpy as np | |
| from PIL import Image, ImageDraw, ImageFont | |
| import albumentations as A | |
| from albumentations.pytorch import ToTensorV2 | |
| import segmentation_models_pytorch as smp | |
| import gradio as gr | |
| from skimage import filters, measure, morphology | |
# Run on GPU when available; all models and input tensors are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Registry of selectable segmentation models for the UI dropdown.
# Each entry maps a display name to its checkpoint path, encoder backbone,
# smp architecture class name (dispatched on in load_model), and a
# human-readable description shown in the "Model Description" textbox.
MODEL_OPTIONS = {
    "UNet++ (ResNet34)": {
        "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
        "encoder": "resnet34",
        "architecture": "UnetPlusPlus",
        "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy."
    },
    "UNet (ResNet34)": {
        "path": "./model/unet_fibril_seg_model.pth",
        "encoder": "resnet34",
        "architecture": "Unet",
        "description": "Classic UNet with ResNet34 encoder — fast and lightweight."
    },
    "UNet++ (efficientnet-b3)": {
        "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
        "encoder": "efficientnet-b3",
        "architecture": "UnetPlusPlus",
        "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU."
    },
    "DeepLabV3Plus (efficientnet-b3)": {
        "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth",
        "encoder": "efficientnet-b3",
        "architecture": "DeepLabV3Plus",
        "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU."
    }
}
def get_transform(size):
    """Build the inference preprocessing pipeline.

    Resizes to a ``size`` x ``size`` square, normalizes single-channel
    intensities to roughly [-1, 1], and converts to a torch tensor.
    """
    steps = [
        A.Resize(size, size),
        A.Normalize(mean=(0.5,), std=(0.5,)),
        ToTensorV2(),
    ]
    return A.Compose(steps)
# Shared preprocessing pipeline; models expect 512x512 single-channel input.
transform = get_transform(512)
# Lazily populated in load_model so each checkpoint is read from disk only once.
model_cache = {}
def load_model(model_name):
    """Return the requested segmentation model, loading and caching on first use.

    Args:
        model_name: A key of MODEL_OPTIONS selecting checkpoint and architecture.

    Returns:
        The model in eval mode, moved to the global ``device``.

    Raises:
        KeyError: if ``model_name`` is not a key of MODEL_OPTIONS.
        ValueError: if the configured architecture string is unsupported.
    """
    if model_name in model_cache:
        return model_cache[model_name]
    config = MODEL_OPTIONS[model_name]
    # encoder_weights=None: the checkpoint below is loaded with strict
    # state_dict matching, which overwrites every parameter — downloading
    # ImageNet encoder weights first would be wasted time and bandwidth.
    if config["architecture"] == "UnetPlusPlus":
        model = smp.UnetPlusPlus(
            encoder_name=config["encoder"], encoder_weights=None,
            decoder_channels=(256, 128, 64, 32, 16),
            in_channels=1, classes=1, activation=None)
    elif config["architecture"] == "Unet":
        model = smp.Unet(
            encoder_name=config["encoder"], encoder_weights=None,
            decoder_channels=(256, 128, 64, 32, 16),
            in_channels=1, classes=1, activation=None)
    elif config["architecture"] == "DeepLabV3Plus":
        model = smp.DeepLabV3Plus(
            encoder_name=config["encoder"], encoder_weights=None,
            in_channels=1, classes=1, activation=None)
    else:
        raise ValueError("Unsupported architecture.")
    model.load_state_dict(torch.load(config["path"], map_location=device))
    model.eval()
    model_cache[model_name] = model.to(device)
    return model_cache[model_name]
def draw_labels_on_image(orig_img, binary_mask, max_oligomer_size, fibril_length_thresh=100, circularity_thresh=1.0):
    """Annotate connected components of *binary_mask* on a copy of *orig_img*.

    Classification per component:
      - Oligomer ("O<n>", green): circularity >= ``circularity_thresh`` AND
        area <= ``max_oligomer_size``.
      - Fibril ("F<n>", red): everything else.  (The original code had two
        fibril branches — one keyed on major-axis length — whose bodies were
        identical, so they are merged here; behavior is unchanged.)

    Args:
        orig_img: PIL image to annotate; it is copied, not modified.
        binary_mask: 0/1 or boolean array; nonzero pixels are foreground.
        max_oligomer_size: maximum area (px) for an oligomer.
        fibril_length_thresh: retained for interface compatibility; it no
            longer affects the label because all non-oligomers are fibrils.
        circularity_thresh: minimum circularity (4*pi*area/perimeter^2) for
            an oligomer.  NOTE(review): with skimage's perimeter estimate a
            rasterized circle rarely reaches 1.0, so the default of 1.0 makes
            oligomer detection extremely strict — consider ~0.7-0.8; verify
            against real masks.

    Returns:
        Tuple of (annotated RGB image, oligomer_count, fibril_count).
    """
    overlay = orig_img.convert("RGB").copy()
    draw = ImageDraw.Draw(overlay)
    # Prefer a TrueType font; fall back to PIL's built-in bitmap font.
    try:
        font = ImageFont.truetype("arial.ttf", 18)
    except IOError:
        font = ImageFont.load_default()
    labeled_mask = measure.label(binary_mask)
    regions = measure.regionprops(labeled_mask)
    oligomer_count = 0
    fibril_count = 0
    for region in regions:
        area = region.area
        perimeter = region.perimeter if region.perimeter > 0 else 1  # guard div-by-zero
        circularity = 4 * np.pi * area / (perimeter ** 2)
        centroid = region.centroid  # (row, col)
        x, y = int(centroid[1]), int(centroid[0])
        if circularity >= circularity_thresh and area <= max_oligomer_size:
            oligomer_count += 1
            label_text = f"O{oligomer_count}"
            label_color = (0, 255, 0)  # green for oligomers
        else:
            fibril_count += 1
            label_text = f"F{fibril_count}"
            label_color = (255, 0, 0)  # red for fibrils
        # Circle the centroid for visibility, then center the label on it.
        r = 12
        draw.ellipse((x - r, y - r, x + r, y + r), outline=label_color, width=2)
        bbox = draw.textbbox((0, 0), label_text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
        text_pos = (x - text_width // 2, y - text_height // 2)
        draw.text(text_pos, label_text, fill=label_color, font=font)
    return overlay, oligomer_count, fibril_count
def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, show_overlay, show_stats, max_oligomer_size, keep_only_oligomers):
    """Segment fibrillar structures in a grayscale microscopy image.

    Parameters
    ----------
    image : PIL.Image or None
        Uploaded image; converted to single-channel ("L") before inference.
    model_name : str
        Key into MODEL_OPTIONS, resolved by load_model().
    threshold : float
        Binarization threshold for the sigmoid probability map (ignored
        when use_otsu is True).
    use_otsu : bool
        Derive the threshold automatically via Otsu's method.
    remove_noise : bool
        Drop connected components smaller than 64 px from the mask.
    fill_holes : bool
        Fill holes smaller than 64 px inside mask objects.
    show_overlay : bool
        Blend the mask over the original image and draw O/F labels.
    show_stats : bool
        Compute area / object-count statistics text.
    max_oligomer_size : int
        Area cutoff (px): components <= cutoff count as oligomers,
        larger ones as fibrils.
    keep_only_oligomers : bool
        If True, remove all components larger than max_oligomer_size
        from the final mask.

    Returns
    -------
    tuple
        (status_message, mask_img, prob_img, overlay_img, mask_file_path,
        stats_text, oligomer_count_str, fibril_count_str) — matching the
        eight Gradio outputs wired to this function.
    """
    if image is None:
        # None (not "") for the gr.File slot: an empty-string path would
        # be treated as a file reference by Gradio.
        return "❌ Please upload an image.", None, None, None, None, "", "", ""

    # ── Inference ──────────────────────────────────────────────
    image = image.convert("L")  # model expects a single grayscale channel
    img_np = np.array(image)
    img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
    model = load_model(model_name)
    # no_grad: pure inference — skip autograd bookkeeping and its memory cost
    with torch.no_grad():
        pred = torch.sigmoid(model(img_tensor)).cpu().squeeze().numpy()

    # ── Thresholding & morphological cleanup ──────────────────
    if use_otsu:
        threshold = filters.threshold_otsu(pred)
    binary_mask = (pred > threshold).astype(np.float32)
    if remove_noise:
        binary_mask = morphology.remove_small_objects(binary_mask > 0, 64)
    if fill_holes:
        binary_mask = morphology.remove_small_holes(binary_mask > 0, 64)
    binary_mask = binary_mask.astype(np.float32)

    # Optionally keep only small (oligomer-sized) components.
    if keep_only_oligomers:
        labeled_mask = measure.label(binary_mask)
        filtered_mask = np.zeros_like(binary_mask)
        for region in measure.regionprops(labeled_mask):
            if region.area <= max_oligomer_size:
                filtered_mask[labeled_mask == region.label] = 1
        binary_mask = filtered_mask.astype(np.float32)

    mask_img = Image.fromarray((binary_mask * 255).astype(np.uint8))
    prob_img = Image.fromarray((pred * 255).astype(np.uint8))

    # ── Overlay with O/F labels ───────────────────────────────
    overlay_img = None
    oligomer_count = 0
    fibril_count = 0
    if show_overlay:
        # Mask is at model resolution; resize (nearest keeps it binary)
        # back to the original image size before blending.
        mask_resized = mask_img.resize(image.size, resample=Image.NEAREST).convert("RGB")
        base_overlay = Image.blend(image.convert("RGB"), mask_resized, alpha=0.4)
        overlay_img, oligomer_count, fibril_count = draw_labels_on_image(
            base_overlay, np.array(mask_img) > 0, max_oligomer_size,
            fibril_length_thresh=100,
        )

    # ── Statistics text ───────────────────────────────────────
    stats_text = ""
    if show_stats:
        labeled_mask = measure.label(binary_mask)
        area = np.sum(binary_mask)
        total_objects = labeled_mask.max()
        # Area-only classification here (the overlay additionally uses
        # shape criteria), so these counts can differ from the labels.
        oligomers = 0
        fibrils = 0
        for region in measure.regionprops(labeled_mask):
            if region.area <= max_oligomer_size:
                oligomers += 1
            else:
                fibrils += 1
        stats_text = (
            f"🧮 Stats:\n"
            f" - Area (px): {area:.0f}\n"
            f" - Total Objects: {total_objects}\n"
            f" - Oligomers (small): {oligomers}\n"
            f" - Fibrils (large): {fibrils}\n"
        )

    # NOTE(review): saving to a fixed cwd filename races between
    # concurrent users on a shared Space — consider tempfile.
    mask_img.save("predicted_mask.png")
    return "✅ Segmentation Complete!", mask_img, prob_img, overlay_img, "predicted_mask.png", stats_text, str(oligomer_count), str(fibril_count)
| css = """ | |
| body { | |
| background: #f9fafb; | |
| color: #2c3e50; | |
| font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; | |
| } | |
| h1, h2, h3 { | |
| color: #34495e; | |
| margin-bottom: 0.2em; | |
| } | |
| .gradio-container { | |
| max-width: 1100px; | |
| margin: 1.5rem auto; | |
| padding: 1rem 2rem; | |
| } | |
| .gr-button { | |
| background-color: #0078d7; | |
| color: white; | |
| font-weight: 600; | |
| border-radius: 8px; | |
| padding: 12px 25px; | |
| } | |
| .gr-button:hover { | |
| background-color: #005a9e; | |
| } | |
| .gr-slider label, .gr-checkbox label { | |
| font-weight: 600; | |
| color: #34495e; | |
| } | |
| .gr-image input[type="file"] { | |
| border-radius: 8px; | |
| } | |
| .gr-file label { | |
| font-weight: 600; | |
| } | |
| .gr-textbox textarea { | |
| font-family: monospace; | |
| font-size: 0.9rem; | |
| background: #ecf0f1; | |
| border-radius: 6px; | |
| padding: 8px; | |
| } | |
| """ | |
# ─── Gradio UI ───────────────────────────────────────────────
# Layout: two columns (input image + examples | model picker and
# segmentation options), then status, result images, stats, counts,
# and a download link. Component creation order defines the layout.
with gr.Blocks(css=css) as demo:
    gr.Markdown("<h1 style='text-align:center; margin-bottom:0.25em;'>🧬 Fibril Segmentation Interface</h1>")
    gr.Markdown("<p style='text-align:center; font-size:1.1rem; color:#555; margin-top:0; margin-bottom:2em;'>Upload a grayscale microscopy image and segment fibrillar structures with advanced deep learning models.</p>")
    with gr.Row():
        with gr.Column(scale=1):
            # Left column: image upload plus bundled example images.
            input_img = gr.Image(label="Upload Grayscale Image", type="pil", interactive=True, elem_id="input-img", sources=["upload"])
            gr.Examples(
                # assumes examples/example1.jpg … example7.jpg exist — TODO confirm
                examples=[[f"examples/example{i}.jpg"] for i in range(1, 8)],
                inputs=input_img,
                label="📁 Try Example Images",
                cache_examples=False,
                elem_id="examples"
            )
        with gr.Column(scale=1):
            # Right column: model selection and post-processing options.
            model_selector = gr.Dropdown(
                choices=list(MODEL_OPTIONS.keys()),
                value="UNet++ (ResNet34)",
                label="Select Model",
                interactive=True
            )
            model_info = gr.Textbox(
                label="Model Description",
                interactive=False,
                lines=3,
                max_lines=5,
                elem_id="model-desc",
                show_label=True,
                container=True
            )
            def update_model_info(name):
                """Return the description text for the selected model."""
                return MODEL_OPTIONS[name]["description"]
            # Keep the description box in sync with the dropdown.
            model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info)
            gr.Markdown("### Segmentation Options")
            threshold_slider = gr.Slider(
                minimum=0, maximum=1, value=0.5, step=0.01,
                label="Segmentation Threshold",
                interactive=True,
                info="Adjust threshold for binarizing segmentation probability."
            )
            use_otsu = gr.Checkbox(
                # When enabled, predict() ignores the manual slider value.
                label="Use Otsu Threshold",
                value=False,
                info="Automatically select optimal threshold using Otsu's method."
            )
            remove_noise = gr.Checkbox(
                label="Remove Small Objects",
                value=False,
                info="Remove small noise blobs from mask."
            )
            fill_holes = gr.Checkbox(
                label="Fill Holes in Mask",
                value=True,
                info="Fill small holes inside segmented objects."
            )
            show_overlay = gr.Checkbox(
                label="Show Overlay on Original",
                value=True,
                info="Display the mask overlaid on the original image."
            )
            show_stats = gr.Checkbox(
                label="Show Area & Confidence Stats",
                value=True,
                info="Display segmentation statistics like area and confidence."
            )
            max_oligomer_size = gr.Slider(
                # Area cutoff used both for classification and (optionally)
                # for filtering the final mask.
                minimum=10, maximum=1000, value=400, step=10,
                label="Max Oligomer Size (px)",
                interactive=True,
                info="Objects smaller or equal to this area (in pixels) are oligomers; larger are fibrils."
            )
            keep_only_oligomers = gr.Checkbox(
                label="Keep Only Oligomers (Remove Large Fibrils)",
                value=False,
                info="If enabled, only oligomers remain in the final mask."
            )
            submit = gr.Button("🟢 Segment Image", variant="primary", elem_id="submit-btn")
    gr.Markdown("---")
    # Output widgets — order matches the 8-tuple returned by predict().
    status = gr.Textbox(label="Status", interactive=False, lines=1, elem_id="status-msg")
    with gr.Row():
        mask_output = gr.Image(label="Binary Mask", interactive=False, type="pil")
        prob_output = gr.Image(label="Confidence Map", interactive=False, type="pil")
        overlay_output = gr.Image(label="Overlay with Labels", interactive=False, type="pil")
    stats_output = gr.Textbox(label="Segmentation Stats", interactive=False, lines=8, elem_id="stats")
    # Optional: show counts separately with big text
    oligomer_count_txt = gr.Textbox(label="Oligomer Count", interactive=False, lines=1)
    fibril_count_txt = gr.Textbox(label="Fibril Count", interactive=False, lines=1)
    file_output = gr.File(label="Download Segmentation Mask")
    submit.click(
        fn=predict,
        inputs=[
            input_img, model_selector, threshold_slider, use_otsu,
            remove_noise, fill_holes, show_overlay, show_stats,
            max_oligomer_size, keep_only_oligomers
        ],
        outputs=[
            status, mask_output, prob_output, overlay_output,
            file_output, stats_output, oligomer_count_txt, fibril_count_txt
        ]
    )
    # Populate the model description once on page load.
    demo.load(fn=update_model_info, inputs=model_selector, outputs=model_info)
if __name__ == "__main__":
    demo.launch()