# # +++++++++++++++ First Version: 1.0.0 +++++++++++++++++++++ # # Last Updated: 08 July 2025 # # Fibril Segmentation with UNet++ using Gradio # import os # import torch # import numpy as np # from PIL import Image # import albumentations as A # from albumentations.pytorch import ToTensorV2 # import segmentation_models_pytorch as smp # import gradio as gr # # ─── Configuration ───────────────────────────────────────── # CONFIG = { # "model_path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth", # "img_size": 512 # } # # ─── Device Setup ────────────────────────────────────────── # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # print(f"✅ Using device: {device}") # # ─── Load Model ──────────────────────────────────────────── # model = smp.UnetPlusPlus( # encoder_name='resnet34', # encoder_depth=5, # encoder_weights='imagenet', # decoder_use_norm='batchnorm', # decoder_channels=(256, 128, 64, 32, 16), # decoder_attention_type=None, # decoder_interpolation='nearest', # in_channels=1, # classes=1, # activation=None # ).to(device) # model.load_state_dict(torch.load(CONFIG["model_path"], map_location=device)) # model.eval() # # ─── Transform Function ──────────────────────────────────── # def get_transform(size): # return A.Compose([ # A.Resize(size, size), # A.Normalize(mean=(0.5,), std=(0.5,)), # ToTensorV2() # ]) # transform = get_transform(CONFIG["img_size"]) # # ─── Prediction Function ─────────────────────────────────── # def predict(image): # image = image.convert("L") # Convert to grayscale # img_np = np.array(image) # img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device) # with torch.no_grad(): # pred = torch.sigmoid(model(img_tensor)) # mask = (pred > 0.5).float().cpu().squeeze().numpy() # mask_img = Image.fromarray((mask * 255).astype(np.uint8)) # return mask_img # # ─── Gradio Interface ────────────────────────────────────── # demo = gr.Interface( # fn=predict, # inputs=gr.Image(type="pil", 
label="Upload Microscopy Image"), # outputs=gr.Image(type="pil", label="Predicted Segmentation Mask"), # title="Fibril Segmentation with Unet++", # description="Upload a grayscale microscopy image to get its predicted segmentation mask.", # allow_flagging="never", # live=False # ) # if __name__ == "__main__": # demo.launch() # # +++++++++++++++ Second Version: 1.1.0 ++++++++++++++++++++ # # Last Updated: 08 July 2025 # # Improvements: Added examples, better UI, and device handling # import os # import torch # import numpy as np # from PIL import Image # import albumentations as A # from albumentations.pytorch import ToTensorV2 # import segmentation_models_pytorch as smp # import gradio as gr # # ─── Configuration ───────────────────────────────────────── # CONFIG = { # "model_path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth", # "img_size": 512 # } # # ─── Device Setup ────────────────────────────────────────── # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # print(f"✅ Using device: {device}") # # ─── Load Model ──────────────────────────────────────────── # model = smp.UnetPlusPlus( # encoder_name='resnet34', # encoder_depth=5, # encoder_weights='imagenet', # decoder_channels=(256, 128, 64, 32, 16), # in_channels=1, # classes=1, # activation=None # ).to(device) # model.load_state_dict(torch.load(CONFIG["model_path"], map_location=device)) # model.eval() # # ─── Transform Function ──────────────────────────────────── # def get_transform(size): # return A.Compose([ # A.Resize(size, size), # A.Normalize(mean=(0.5,), std=(0.5,)), # ToTensorV2() # ]) # transform = get_transform(CONFIG["img_size"]) # # ─── Prediction Function ─────────────────────────────────── # def predict(image): # image = image.convert("L") # Ensure grayscale # img_np = np.array(image) # img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device) # with torch.no_grad(): # pred = torch.sigmoid(model(img_tensor)) # mask = (pred > 
0.5).float().cpu().squeeze().numpy() # mask_img = Image.fromarray((mask * 255).astype(np.uint8)) # return mask_img # # ─── Gradio UI (Improved) ────────────────────────────────── # examples = [ # ["examples/example1.jpg"], # ["examples/example2.jpg"], # ["examples/example3.jpg"], # ["examples/example4.jpg"], # ["examples/example5.jpg"], # ["examples/example6.jpg"], # ["examples/example7.jpg"] # ] # css = """ # .gradio-container { # max-width: 950px; # margin: auto; # } # .gr-button { # background-color: #4a90e2; # color: white; # border-radius: 5px; # } # .gr-button:hover { # background-color: #357ABD; # } # """ # with gr.Blocks(css=css) as demo: # gr.Markdown("## 🧬 Fibril Segmentation with UNet++") # gr.Markdown("Upload a **grayscale microscopy image**, and this model will predict the **segmentation mask of fibrillar structures**.\n\nModel: ResNet34 encoder + UNet++ decoder") # with gr.Row(): # input_img = gr.Image(label="Upload Microscopy Image", type="pil") # output_mask = gr.Image(label="Predicted Segmentation Mask", type="pil") # submit_btn = gr.Button("Segment Image") # submit_btn.click(fn=predict, inputs=input_img, outputs=output_mask) # gr.Examples( # examples=examples, # inputs=input_img, # label="Try with Example Images", # cache_examples=False # ) # # ─── Launch App ──────────────────────────────────────────── # if __name__ == "__main__": # demo.launch() # +++++++++++++++ Final Version: 1.2.0 ++++++++++++++++++++++ # Last Updated: 08 July 2025 # Improvements: Added model selection, better UI, and device handling # import os # import torch # import numpy as np # from PIL import Image # import albumentations as A # from albumentations.pytorch import ToTensorV2 # import segmentation_models_pytorch as smp # import gradio as gr # # ─── Device Setup ────────────────────────────────────────── # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # print(f"✅ Using device: {device}") # # ─── Model Configurations 
────────────────────────────────── # MODEL_OPTIONS = { # "UNet++ (ResNet34)": { # "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth", # "encoder": "resnet34", # "architecture": "UnetPlusPlus" # }, # "UNet (ResNet34)": { # "path": "./model/unet_fibril_seg_model.pth", # "encoder": "resnet34", # "architecture": "Unet" # }, # "UNet++ (efficientnet-b3)": { # "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth", # "encoder": "efficientnet-b3", # "architecture": "UnetPlusPlus" # }, # "DeepLabV3Plus (efficientnet-b3)": { # "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth", # "encoder": "efficientnet-b3", # "architecture": "DeepLabV3Plus" # } # } # # ─── Transform Function ──────────────────────────────────── # def get_transform(size): # return A.Compose([ # A.Resize(size, size), # A.Normalize(mean=(0.5,), std=(0.5,)), # ToTensorV2() # ]) # transform = get_transform(512) # # ─── Model Loader ────────────────────────────────────────── # # def load_model(model_name): # # config = MODEL_OPTIONS[model_name] # # if config["architecture"] == "UnetPlusPlus": # # model = smp.UnetPlusPlus( # # encoder_name=config["encoder"], # # encoder_weights="imagenet", # # decoder_channels=(256, 128, 64, 32, 16), # # in_channels=1, # # classes=1, # # activation=None # # ) # # elif config["architecture"] == "Unet": # # model = smp.Unet( # # encoder_name=config["encoder"], # # encoder_weights="imagenet", # # decoder_channels=(256, 128, 64, 32, 16), # # in_channels=1, # # classes=1, # # activation=None # # ) # # else: # # raise ValueError(f"Unsupported architecture: {config['architecture']}") # # model.load_state_dict(torch.load(config["path"], map_location=device)) # # model.eval() # # return model.to(device) # model_cache = {} # def load_model(model_name): # if model_name in model_cache: # return model_cache[model_name] # config = MODEL_OPTIONS[model_name] # if config["architecture"] == "UnetPlusPlus": # 
model = smp.UnetPlusPlus( # encoder_name=config["encoder"], # encoder_weights="imagenet", # decoder_channels=(256, 128, 64, 32, 16), # in_channels=1, # classes=1, # activation=None # ) # elif config["architecture"] == "Unet": # model = smp.Unet( # encoder_name=config["encoder"], # encoder_weights="imagenet", # decoder_channels=(256, 128, 64, 32, 16), # in_channels=1, # classes=1, # activation=None # ) # elif config["architecture"] == "DeepLabV3Plus": # model = smp.DeepLabV3Plus( # encoder_name=config["encoder"], # encoder_weights="imagenet", # # decoder_channels=(256, 128, 64, 32, 16), # Not used in DeepLabV3Plus # in_channels=1, # classes=1, # activation=None # ) # else: # raise ValueError(f"Unsupported architecture: {config['architecture']}") # model.load_state_dict(torch.load(config["path"], map_location=device)) # model.eval() # model_cache[model_name] = model.to(device) # return model_cache[model_name] # # ─── Prediction Function ─────────────────────────────────── # def predict(image, model_name): # image = image.convert("L") # img_np = np.array(image) # img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device) # model = load_model(model_name) # with torch.no_grad(): # pred = torch.sigmoid(model(img_tensor)) # mask = (pred > 0.5).float().cpu().squeeze().numpy() # mask_img = Image.fromarray((mask * 255).astype(np.uint8)) # return mask_img # # ─── Example Images ──────────────────────────────────────── # examples = [ # ["examples/example1.jpg"], # ["examples/example2.jpg"], # ["examples/example3.jpg"], # ["examples/example4.jpg"], # ["examples/example5.jpg"], # ["examples/example6.jpg"], # ["examples/example7.jpg"] # ] # # ─── Custom CSS ──────────────────────────────────────────── # css = """ # .gradio-container { # max-width: 950px; # margin: auto; # } # .gr-button { # background-color: #4a90e2; # color: white; # border-radius: 5px; # } # .gr-button:hover { # background-color: #357ABD; # } # """ # # ─── Gradio UI 
───────────────────────────────────────────── # with gr.Blocks(css=css) as demo: # gr.Markdown("## 🧬 Fibril Segmentation Interface") # gr.Markdown("Choose a model and upload a grayscale microscopy image. The model will predict the **fibrillar structure mask**.") # with gr.Row(): # model_selector = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="UNet++ (ResNet34)", label="Select Model") # with gr.Row(): # input_img = gr.Image(label="Upload Microscopy Image", type="pil") # output_mask = gr.Image(label="Predicted Segmentation Mask", type="pil") # submit_btn = gr.Button("Segment Image") # submit_btn.click(fn=predict, inputs=[input_img, model_selector], outputs=output_mask) # gr.Examples( # examples=examples, # inputs=input_img, # label="Try with Example Images", # cache_examples=False # ) # # ─── Launch App ──────────────────────────────────────────── # if __name__ == "__main__": # demo.launch() # +++++++++++++++ Final Version: 1.3.0 ++++++++++++++++++++++ # Last Updated: 10 July 2025 # Improvements: Added model selection, better UI, and device handling # import os # import torch # import numpy as np # from PIL import Image # import albumentations as A # from albumentations.pytorch import ToTensorV2 # import segmentation_models_pytorch as smp # import gradio as gr # # ─── Device Setup ────────────────────────────────────────── # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # print(f"✅ Using device: {device}") # # ─── Model Configurations ────────────────────────────────── # MODEL_OPTIONS = { # "UNet++ (ResNet34)": { # "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth", # "encoder": "resnet34", # "architecture": "UnetPlusPlus", # "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy." # }, # "UNet (ResNet34)": { # "path": "./model/unet_fibril_seg_model.pth", # "encoder": "resnet34", # "architecture": "Unet", # "description": "Classic UNet with ResNet34 encoder — fast and lightweight." 
# }, # "UNet++ (efficientnet-b3)": { # "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth", # "encoder": "efficientnet-b3", # "architecture": "UnetPlusPlus", # "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU." # }, # "DeepLabV3Plus (efficientnet-b3)": { # "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth", # "encoder": "efficientnet-b3", # "architecture": "DeepLabV3Plus", # "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU." # } # } # # ─── Transform Function ──────────────────────────────────── # def get_transform(size): # return A.Compose([ # A.Resize(size, size), # A.Normalize(mean=(0.5,), std=(0.5,)), # ToTensorV2() # ]) # transform = get_transform(512) # # ─── Model Loader ────────────────────────────────────────── # model_cache = {} # def load_model(model_name): # if model_name in model_cache: # return model_cache[model_name] # config = MODEL_OPTIONS[model_name] # if config["architecture"] == "UnetPlusPlus": # model = smp.UnetPlusPlus( # encoder_name=config["encoder"], # encoder_weights="imagenet", # decoder_channels=(256, 128, 64, 32, 16), # in_channels=1, # classes=1, # activation=None # ) # elif config["architecture"] == "Unet": # model = smp.Unet( # encoder_name=config["encoder"], # encoder_weights="imagenet", # decoder_channels=(256, 128, 64, 32, 16), # in_channels=1, # classes=1, # activation=None # ) # elif config["architecture"] == "DeepLabV3Plus": # model = smp.DeepLabV3Plus( # encoder_name=config["encoder"], # encoder_weights="imagenet", # in_channels=1, # classes=1, # activation=None # ) # else: # raise ValueError(f"Unsupported architecture: {config['architecture']}") # model.load_state_dict(torch.load(config["path"], map_location=device)) # model.eval() # model_cache[model_name] = model.to(device) # return model_cache[model_name] # # ─── Prediction Function ─────────────────────────────────── # @torch.no_grad() # def 
predict(image, model_name, threshold): # if image is None: # return "❌ Please upload an image.", None, None, None # image = image.convert("L") # img_np = np.array(image) # img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device) # model = load_model(model_name) # pred = torch.sigmoid(model(img_tensor)) # prob_map = pred.cpu().squeeze().numpy() # mask = (prob_map > threshold).astype(np.float32) # mask_img = Image.fromarray((mask * 255).astype(np.uint8)) # prob_img = Image.fromarray((prob_map * 255).astype(np.uint8)) # # Save mask temporarily for download # out_path = "predicted_mask.png" # mask_img.save(out_path) # return "✅ Done!", mask_img, prob_img, out_path # # ─── UI and Layout ───────────────────────────────────────── # css = """ # .gradio-container { # max-width: 950px; # margin: auto; # } # .gr-button { # background-color: #4a90e2; # color: white; # border-radius: 5px; # } # .gr-button:hover { # background-color: #357ABD; # } # """ # with gr.Blocks(css=css) as demo: # gr.Markdown("## 🧬 Fibril Segmentation Interface") # gr.Markdown("Upload a grayscale microscopy image, choose a model, adjust threshold, and get a fibrillar structure mask.") # with gr.Row(): # model_selector = gr.Dropdown( # choices=list(MODEL_OPTIONS.keys()), # value="UNet++ (ResNet34)", # label="Select Model" # ) # threshold_slider = gr.Slider( # minimum=0.0, maximum=1.0, value=0.5, step=0.01, # label="Segmentation Threshold" # ) # model_info = gr.Textbox( # label="Model Description", # lines=3, # interactive=False, # value=MODEL_OPTIONS[model_selector.value]["description"] # ) # def update_model_info(model_name): # return MODEL_OPTIONS[model_name]["description"] # model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info) # with gr.Tabs(): # with gr.Tab("🖼️ Segmentation"): # with gr.Row(): # input_img = gr.Image( # label="Upload Microscopy Image", # type="pil", # ) # output_mask = gr.Image(label="Predicted Mask", type="pil") # confidence_map = 
gr.Image(label="Confidence Map", type="pil") # submit_btn = gr.Button("Segment Image") # status_box = gr.Textbox(label="Status", interactive=False) # output_download = gr.File(label="Download Predicted Mask") # submit_btn.click( # fn=predict, # inputs=[input_img, model_selector, threshold_slider], # outputs=[status_box, output_mask, confidence_map, output_download] # ) # with gr.Tab("📁 Examples"): # gr.Examples( # examples=[ # ["examples/example1.jpg"], # ["examples/example2.jpg"], # ["examples/example3.jpg"], # ["examples/example4.jpg"], # ["examples/example5.jpg"], # ["examples/example6.jpg"], # ["examples/example7.jpg"] # ], # inputs=input_img, # label="Try with Example Images", # cache_examples=False # ) # if __name__ == "__main__": # demo.launch() # +++++++++++++++ Final Version: 1.4.0 ++++++++++++++++++++++ # Last Updated: 10 July 2025 # Improvements: Added model selection, better UI, and device handling # import os # import torch # import numpy as np # from PIL import Image # import albumentations as A # from albumentations.pytorch import ToTensorV2 # import segmentation_models_pytorch as smp # import gradio as gr # # ─── Device Setup ────────────────────────────────────────── # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # print(f"✅ Using device: {device}") # # ─── Model Configurations ────────────────────────────────── # MODEL_OPTIONS = { # "UNet++ (ResNet34)": { # "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth", # "encoder": "resnet34", # "architecture": "UnetPlusPlus", # "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy." # }, # "UNet (ResNet34)": { # "path": "./model/unet_fibril_seg_model.pth", # "encoder": "resnet34", # "architecture": "Unet", # "description": "Classic UNet with ResNet34 encoder — fast and lightweight." 
# }, # "UNet++ (efficientnet-b3)": { # "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth", # "encoder": "efficientnet-b3", # "architecture": "UnetPlusPlus", # "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU." # }, # "DeepLabV3Plus (efficientnet-b3)": { # "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth", # "encoder": "efficientnet-b3", # "architecture": "DeepLabV3Plus", # "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU." # } # } # # ─── Transform Function ──────────────────────────────────── # def get_transform(size): # return A.Compose([ # A.Resize(size, size), # A.Normalize(mean=(0.5,), std=(0.5,)), # ToTensorV2() # ]) # transform = get_transform(512) # # ─── Model Loader ────────────────────────────────────────── # model_cache = {} # def load_model(model_name): # if model_name in model_cache: # return model_cache[model_name] # config = MODEL_OPTIONS[model_name] # if config["architecture"] == "UnetPlusPlus": # model = smp.UnetPlusPlus( # encoder_name=config["encoder"], # encoder_weights="imagenet", # decoder_channels=(256, 128, 64, 32, 16), # in_channels=1, # classes=1, # activation=None # ) # elif config["architecture"] == "Unet": # model = smp.Unet( # encoder_name=config["encoder"], # encoder_weights="imagenet", # decoder_channels=(256, 128, 64, 32, 16), # in_channels=1, # classes=1, # activation=None # ) # elif config["architecture"] == "DeepLabV3Plus": # model = smp.DeepLabV3Plus( # encoder_name=config["encoder"], # encoder_weights="imagenet", # in_channels=1, # classes=1, # activation=None # ) # else: # raise ValueError(f"Unsupported architecture: {config['architecture']}") # model.load_state_dict(torch.load(config["path"], map_location=device)) # model.eval() # model_cache[model_name] = model.to(device) # return model_cache[model_name] # # ─── Prediction Function ─────────────────────────────────── # @torch.no_grad() # def 
predict(image, model_name, threshold): # if image is None: # return "❌ Please upload an image.", None, None, None # image = image.convert("L") # img_np = np.array(image) # img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device) # model = load_model(model_name) # pred = torch.sigmoid(model(img_tensor)) # prob_map = pred.cpu().squeeze().numpy() # mask = (prob_map > threshold).astype(np.float32) # mask_img = Image.fromarray((mask * 255).astype(np.uint8)) # prob_img = Image.fromarray((prob_map * 255).astype(np.uint8)) # out_path = "predicted_mask.png" # mask_img.save(out_path) # return "✅ Done!", mask_img, prob_img, out_path # # ─── UI and Layout ───────────────────────────────────────── # css = """ # .gradio-container { # max-width: 1100px; # margin: auto; # } # .gr-button { # background-color: #4a90e2; # color: white; # border-radius: 5px; # } # .gr-button:hover { # background-color: #357ABD; # } # """ # with gr.Blocks(css=css) as demo: # gr.Markdown("## 🧬 Fibril Segmentation Interface") # gr.Markdown("Upload a grayscale microscopy image, choose a model, adjust threshold, and get a fibrillar structure mask.") # with gr.Row(): # # Left Panel – Examples and Upload # with gr.Column(scale=1): # gr.Markdown("### 📁 Example Images") # input_img = gr.Image( # label="Upload Microscopy Image", # type="pil", # interactive=True # ) # gr.Examples( # examples=[ # ["examples/example1.jpg"], # ["examples/example2.jpg"], # ["examples/example3.jpg"], # ["examples/example4.jpg"], # ["examples/example5.jpg"], # ["examples/example6.jpg"], # ["examples/example7.jpg"] # ], # inputs=input_img, # label="Try with Example Images", # cache_examples=False # ) # # Right Panel – Controls & Output # with gr.Column(scale=2): # with gr.Row(): # model_selector = gr.Dropdown( # choices=list(MODEL_OPTIONS.keys()), # value="UNet++ (ResNet34)", # label="Select Model" # ) # threshold_slider = gr.Slider( # minimum=0.0, maximum=1.0, value=0.5, step=0.01, # label="Segmentation Threshold" # ) # 
model_info = gr.Textbox( # label="Model Description", # lines=3, # interactive=False, # value=MODEL_OPTIONS["UNet++ (ResNet34)"]["description"] # ) # def update_model_info(model_name): # return MODEL_OPTIONS[model_name]["description"] # model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info) # with gr.Row(): # output_mask = gr.Image(label="Predicted Mask", type="pil") # confidence_map = gr.Image(label="Confidence Map", type="pil") # submit_btn = gr.Button("Segment Image") # status_box = gr.Textbox(label="Status", interactive=False) # output_download = gr.File(label="Download Predicted Mask") # submit_btn.click( # fn=predict, # inputs=[input_img, model_selector, threshold_slider], # outputs=[status_box, output_mask, confidence_map, output_download] # ) # # ─── Launch App ──────────────────────────────────────────── # if __name__ == "__main__": # demo.launch() # +++++++++++++++ Final Version: 1.5.0 ++++++++++++++++++++++ # Last Updated: 10 July 2025 # Improvements: Added model selection, better UI, and device handling # import os # import torch # import numpy as np # from PIL import Image # import albumentations as A # from albumentations.pytorch import ToTensorV2 # import segmentation_models_pytorch as smp # import gradio as gr # from skimage import filters, measure, morphology # # ─── Device Setup ───────────────────────────── # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # print(f"✅ Using device: {device}") # # ─── Model Configurations ───────────────────── # MODEL_OPTIONS = { # "UNet++ (ResNet34)": { # "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth", # "encoder": "resnet34", # "architecture": "UnetPlusPlus", # "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy." 
# }, # "UNet (ResNet34)": { # "path": "./model/unet_fibril_seg_model.pth", # "encoder": "resnet34", # "architecture": "Unet", # "description": "Classic UNet with ResNet34 encoder — fast and lightweight." # }, # "UNet++ (efficientnet-b3)": { # "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth", # "encoder": "efficientnet-b3", # "architecture": "UnetPlusPlus", # "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU." # }, # "DeepLabV3Plus (efficientnet-b3)": { # "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth", # "encoder": "efficientnet-b3", # "architecture": "DeepLabV3Plus", # "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU." # } # } # def get_transform(size): # return A.Compose([ # A.Resize(size, size), # A.Normalize(mean=(0.5,), std=(0.5,)), # ToTensorV2() # ]) # transform = get_transform(512) # model_cache = {} # def load_model(model_name): # if model_name in model_cache: # return model_cache[model_name] # config = MODEL_OPTIONS[model_name] # if config["architecture"] == "UnetPlusPlus": # model = smp.UnetPlusPlus( # encoder_name=config["encoder"], encoder_weights="imagenet", # decoder_channels=(256, 128, 64, 32, 16), # in_channels=1, classes=1, activation=None) # elif config["architecture"] == "Unet": # model = smp.Unet( # encoder_name=config["encoder"], encoder_weights="imagenet", # decoder_channels=(256, 128, 64, 32, 16), # in_channels=1, classes=1, activation=None) # elif config["architecture"] == "DeepLabV3Plus": # model = smp.DeepLabV3Plus( # encoder_name=config["encoder"], encoder_weights="imagenet", # in_channels=1, classes=1, activation=None) # else: # raise ValueError("Unsupported architecture.") # model.load_state_dict(torch.load(config["path"], map_location=device)) # model.eval() # model_cache[model_name] = model.to(device) # return model_cache[model_name] # # ─── Prediction Logic ───────────────────────── # @torch.no_grad() # 
def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, show_overlay, show_stats): # if image is None: # return "❌ Please upload an image.", None, None, None, "", "" # image = image.convert("L") # img_np = np.array(image) # img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device) # model = load_model(model_name) # pred = torch.sigmoid(model(img_tensor)).cpu().squeeze().numpy() # if use_otsu: # threshold = filters.threshold_otsu(pred) # binary_mask = (pred > threshold).astype(np.float32) # # Post-process mask # if remove_noise: # binary_mask = morphology.remove_small_objects(binary_mask > 0, 64) # if fill_holes: # binary_mask = morphology.remove_small_holes(binary_mask > 0, 64) # binary_mask = binary_mask.astype(np.float32) # mask_img = Image.fromarray((binary_mask * 255).astype(np.uint8)) # prob_img = Image.fromarray((pred * 255).astype(np.uint8)) # if show_overlay: # mask_resized = mask_img.resize(image.size, resample=Image.NEAREST).convert("RGB") # overlay_img = Image.blend(image.convert("RGB"), mask_resized, alpha=0.4) # else: # overlay_img = None # stats_text = "" # if show_stats: # labeled_mask = measure.label(binary_mask) # area = np.sum(binary_mask) # mean_conf = np.mean(pred[binary_mask > 0]) if area > 0 else 0 # std_conf = np.std(pred[binary_mask > 0]) if area > 0 else 0 # stats_text = ( # f"🧮 Stats:\n - Area (px): {area:.0f}\n" # f" - Objects: {labeled_mask.max()}\n" # f" - Mean Conf: {mean_conf:.3f}\n" # f" - Std Conf: {std_conf:.3f}" # ) # mask_img.save("predicted_mask.png") # return "✅ Done!", mask_img, prob_img, overlay_img, "predicted_mask.png", stats_text # # ─── UI ─────────────────────────────────────── # css = """ # .gradio-container { max-width: 1100px; margin: auto; } # .gr-button { background-color: #4a90e2; color: white; border-radius: 5px; } # .gr-button:hover { background-color: #357ABD; } # """ # with gr.Blocks(css=css) as demo: # gr.Markdown("## 🧬 Fibril Segmentation Interface") # gr.Markdown("Upload a 
grayscale microscopy image and choose options to segment fibrillar structures.") # with gr.Row(): # with gr.Column(scale=1): # input_img = gr.Image(label="Upload Image", type="pil", interactive=True) # gr.Examples( # examples=[[f"examples/example{i}.jpg"] for i in range(1, 8)], # inputs=input_img, # label="📁 Try Example Images", cache_examples=False # ) # with gr.Column(scale=2): # model_selector = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="UNet++ (ResNet34)", label="Model") # threshold_slider = gr.Slider(0, 1, value=0.5, step=0.01, label="Segmentation Threshold") # use_otsu = gr.Checkbox(label="Use Otsu Threshold", value=False) # remove_noise = gr.Checkbox(label="Remove Small Objects", value=True) # fill_holes = gr.Checkbox(label="Fill Holes in Mask", value=True) # show_overlay = gr.Checkbox(label="Show Overlay on Original", value=True) # show_stats = gr.Checkbox(label="Show Area & Confidence Stats", value=True) # model_info = gr.Textbox(label="Model Description", interactive=False, lines=2) # def update_model_info(name): return MODEL_OPTIONS[name]["description"] # model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info) # submit = gr.Button("Segment Image") # status = gr.Textbox(label="Status", interactive=False) # with gr.Row(): # mask_output = gr.Image(label="Binary Mask") # prob_output = gr.Image(label="Confidence Map") # overlay_output = gr.Image(label="Overlay") # stats_output = gr.Textbox(label="Segmentation Stats", lines=4) # file_output = gr.File(label="Download Mask") # submit.click( # fn=predict, # inputs=[ # input_img, model_selector, threshold_slider, use_otsu, # remove_noise, fill_holes, show_overlay, show_stats # ], # outputs=[ # status, mask_output, prob_output, overlay_output, # file_output, stats_output # ] # ) # if __name__ == "__main__": # demo.launch() # +++++++++++++++ Final Version: 1.6.0 ++++++++++++++++++++++ # Last Updated: 10 July 2025 # Improvements: Added model selection, better UI, and device 
handling # import os # import torch # import numpy as np # from PIL import Image # import albumentations as A # from albumentations.pytorch import ToTensorV2 # import segmentation_models_pytorch as smp # import gradio as gr # from skimage import filters, measure, morphology # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # MODEL_OPTIONS = { # "UNet++ (ResNet34)": { # "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth", # "encoder": "resnet34", # "architecture": "UnetPlusPlus", # "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy." # }, # "UNet (ResNet34)": { # "path": "./model/unet_fibril_seg_model.pth", # "encoder": "resnet34", # "architecture": "Unet", # "description": "Classic UNet with ResNet34 encoder — fast and lightweight." # }, # "UNet++ (efficientnet-b3)": { # "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth", # "encoder": "efficientnet-b3", # "architecture": "UnetPlusPlus", # "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU." # }, # "DeepLabV3Plus (efficientnet-b3)": { # "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth", # "encoder": "efficientnet-b3", # "architecture": "DeepLabV3Plus", # "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU." 
# } # } # def get_transform(size): # return A.Compose([ # A.Resize(size, size), # A.Normalize(mean=(0.5,), std=(0.5,)), # ToTensorV2() # ]) # transform = get_transform(512) # model_cache = {} # def load_model(model_name): # if model_name in model_cache: # return model_cache[model_name] # config = MODEL_OPTIONS[model_name] # if config["architecture"] == "UnetPlusPlus": # model = smp.UnetPlusPlus( # encoder_name=config["encoder"], encoder_weights="imagenet", # decoder_channels=(256, 128, 64, 32, 16), # in_channels=1, classes=1, activation=None) # elif config["architecture"] == "Unet": # model = smp.Unet( # encoder_name=config["encoder"], encoder_weights="imagenet", # decoder_channels=(256, 128, 64, 32, 16), # in_channels=1, classes=1, activation=None) # elif config["architecture"] == "DeepLabV3Plus": # model = smp.DeepLabV3Plus( # encoder_name=config["encoder"], encoder_weights="imagenet", # in_channels=1, classes=1, activation=None) # else: # raise ValueError("Unsupported architecture.") # model.load_state_dict(torch.load(config["path"], map_location=device)) # model.eval() # model_cache[model_name] = model.to(device) # return model_cache[model_name] # @torch.no_grad() # def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, show_overlay, show_stats): # if image is None: # return "❌ Please upload an image.", None, None, None, "", "" # image = image.convert("L") # img_np = np.array(image) # img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device) # model = load_model(model_name) # pred = torch.sigmoid(model(img_tensor)).cpu().squeeze().numpy() # if use_otsu: # threshold = filters.threshold_otsu(pred) # binary_mask = (pred > threshold).astype(np.float32) # if remove_noise: # binary_mask = morphology.remove_small_objects(binary_mask > 0, 64) # if fill_holes: # binary_mask = morphology.remove_small_holes(binary_mask > 0, 64) # binary_mask = binary_mask.astype(np.float32) # mask_img = Image.fromarray((binary_mask * 
255).astype(np.uint8)) # prob_img = Image.fromarray((pred * 255).astype(np.uint8)) # if show_overlay: # mask_resized = mask_img.resize(image.size, resample=Image.NEAREST).convert("RGB") # overlay_img = Image.blend(image.convert("RGB"), mask_resized, alpha=0.4) # else: # overlay_img = None # stats_text = "" # if show_stats: # labeled_mask = measure.label(binary_mask) # area = np.sum(binary_mask) # mean_conf = np.mean(pred[binary_mask > 0]) if area > 0 else 0 # std_conf = np.std(pred[binary_mask > 0]) if area > 0 else 0 # stats_text = ( # f"🧮 Stats:\n - Area (px): {area:.0f}\n" # f" - Objects: {labeled_mask.max()}\n" # f" - Mean Conf: {mean_conf:.3f}\n" # f" - Std Conf: {std_conf:.3f}" # ) # mask_img.save("predicted_mask.png") # return "✅ Segmentation Complete!", mask_img, prob_img, overlay_img, "predicted_mask.png", stats_text # css = """ # body { # background: #f9fafb; # color: #2c3e50; # font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; # } # h1, h2, h3 { # color: #34495e; # margin-bottom: 0.2em; # } # .gradio-container { # max-width: 1100px; # margin: 1.5rem auto; # padding: 1rem 2rem; # } # .gr-button { # background-color: #0078d7; # color: white; # font-weight: 600; # border-radius: 8px; # padding: 12px 25px; # } # .gr-button:hover { # background-color: #005a9e; # } # .gr-slider label, .gr-checkbox label { # font-weight: 600; # color: #34495e; # } # .gr-image input[type="file"] { # border-radius: 8px; # } # .gr-file label { # font-weight: 600; # } # .gr-textbox textarea { # font-family: monospace; # font-size: 0.9rem; # background: #ecf0f1; # border-radius: 6px; # padding: 8px; # } # """ # with gr.Blocks(css=css) as demo: # gr.Markdown("
Upload a grayscale microscopy image and segment fibrillar structures with advanced deep learning models.
") # with gr.Row(): # with gr.Column(scale=1): # input_img = gr.Image(label="Upload Grayscale Image", type="pil", interactive=True, elem_id="input-img", sources=["upload"]) # gr.Examples( # examples=[[f"examples/example{i}.jpg"] for i in range(1, 8)], # inputs=input_img, # label="📁 Try Example Images", # cache_examples=False, # elem_id="examples" # ) # with gr.Column(scale=1): # model_selector = gr.Dropdown( # choices=list(MODEL_OPTIONS.keys()), # value="UNet++ (ResNet34)", # label="Select Model", # interactive=True # ) # model_info = gr.Textbox( # label="Model Description", # interactive=False, # lines=3, # max_lines=5, # elem_id="model-desc", # show_label=True, # container=True # ) # def update_model_info(name): # return MODEL_OPTIONS[name]["description"] # model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info) # gr.Markdown("### Segmentation Options") # threshold_slider = gr.Slider( # minimum=0, maximum=1, value=0.5, step=0.01, # label="Segmentation Threshold", # interactive=True, # info="Adjust threshold for binarizing segmentation probability." # ) # use_otsu = gr.Checkbox( # label="Use Otsu Threshold", # value=False, # info="Automatically select optimal threshold using Otsu's method." # ) # remove_noise = gr.Checkbox( # label="Remove Small Objects", # value=False, # info="Remove small noise blobs from mask." # ) # fill_holes = gr.Checkbox( # label="Fill Holes in Mask", # value=True, # info="Fill small holes inside segmented objects." # ) # show_overlay = gr.Checkbox( # label="Show Overlay on Original", # value=True, # info="Display the mask overlaid on the original image." # ) # show_stats = gr.Checkbox( # label="Show Area & Confidence Stats", # value=True, # info="Display segmentation statistics like area and confidence." 
# ) # submit = gr.Button("🟢 Segment Image", variant="primary", elem_id="submit-btn") # gr.Markdown("---") # status = gr.Textbox(label="Status", interactive=False, lines=1, elem_id="status-msg") # with gr.Row(): # mask_output = gr.Image(label="Binary Mask", interactive=False, type="pil") # prob_output = gr.Image(label="Confidence Map", interactive=False, type="pil") # overlay_output = gr.Image(label="Overlay", interactive=False, type="pil") # stats_output = gr.Textbox(label="Segmentation Stats", interactive=False, lines=6, elem_id="stats") # file_output = gr.File(label="Download Segmentation Mask") # submit.click( # fn=predict, # inputs=[ # input_img, model_selector, threshold_slider, use_otsu, # remove_noise, fill_holes, show_overlay, show_stats # ], # outputs=[ # status, mask_output, prob_output, overlay_output, # file_output, stats_output # ] # ) # demo.load(fn=update_model_info, inputs=model_selector, outputs=model_info) # if __name__ == "__main__": # demo.launch() # +++++++++++++++ Final Version: 1.7.0 ++++++++++++++++++++++ # Last Updated: 10 July 2025 # Improvements: Added model selection, better UI, and device handling, oligomer labeling and fibril labeling # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ # import os # import torch # import numpy as np # from PIL import Image, ImageDraw, ImageFont # import albumentations as A # from albumentations.pytorch import ToTensorV2 # import segmentation_models_pytorch as smp # import gradio as gr # from skimage import filters, measure, morphology # device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # MODEL_OPTIONS = { # "UNet++ (ResNet34)": { # "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth", # "encoder": "resnet34", # "architecture": "UnetPlusPlus", # "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy." 
# }, # "UNet (ResNet34)": { # "path": "./model/unet_fibril_seg_model.pth", # "encoder": "resnet34", # "architecture": "Unet", # "description": "Classic UNet with ResNet34 encoder — fast and lightweight." # }, # "UNet++ (efficientnet-b3)": { # "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth", # "encoder": "efficientnet-b3", # "architecture": "UnetPlusPlus", # "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU." # }, # "DeepLabV3Plus (efficientnet-b3)": { # "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth", # "encoder": "efficientnet-b3", # "architecture": "DeepLabV3Plus", # "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU." # } # } # def get_transform(size): # return A.Compose([ # A.Resize(size, size), # A.Normalize(mean=(0.5,), std=(0.5,)), # ToTensorV2() # ]) # transform = get_transform(512) # model_cache = {} # def load_model(model_name): # if model_name in model_cache: # return model_cache[model_name] # config = MODEL_OPTIONS[model_name] # if config["architecture"] == "UnetPlusPlus": # model = smp.UnetPlusPlus( # encoder_name=config["encoder"], encoder_weights="imagenet", # decoder_channels=(256, 128, 64, 32, 16), # in_channels=1, classes=1, activation=None) # elif config["architecture"] == "Unet": # model = smp.Unet( # encoder_name=config["encoder"], encoder_weights="imagenet", # decoder_channels=(256, 128, 64, 32, 16), # in_channels=1, classes=1, activation=None) # elif config["architecture"] == "DeepLabV3Plus": # model = smp.DeepLabV3Plus( # encoder_name=config["encoder"], encoder_weights="imagenet", # in_channels=1, classes=1, activation=None) # else: # raise ValueError("Unsupported architecture.") # model.load_state_dict(torch.load(config["path"], map_location=device)) # model.eval() # model_cache[model_name] = model.to(device) # return model_cache[model_name] # def draw_labels_on_image(orig_img, binary_mask, max_oligomer_size): # 
""" # Draws numbers on the overlay image labeling oligomers and fibrils. # Oligomers labeled as O1, O2, ... # Fibrils labeled as F1, F2, ... # """ # overlay = orig_img.convert("RGB").copy() # draw = ImageDraw.Draw(overlay) # # Try to get a nice font; fallback to default if not available # try: # font = ImageFont.truetype("arial.ttf", 18) # except IOError: # font = ImageFont.load_default() # labeled_mask = measure.label(binary_mask) # regions = measure.regionprops(labeled_mask) # oligomer_count = 0 # fibril_count = 0 # for region in regions: # area = region.area # centroid = region.centroid # (row, col) # x, y = int(centroid[1]), int(centroid[0]) # if area <= max_oligomer_size: # oligomer_count += 1 # label_text = f"O{oligomer_count}" # label_color = (0, 255, 0) # Green for oligomers # else: # fibril_count += 1 # label_text = f"F{fibril_count}" # label_color = (255, 0, 0) # Red for fibrils # # Draw circle around centroid for visibility # r = 12 # draw.ellipse((x-r, y-r, x+r, y+r), outline=label_color, width=2) # # Draw label text # bbox = draw.textbbox((0, 0), label_text, font=font) # text_width = bbox[2] - bbox[0] # text_height = bbox[3] - bbox[1] # text_pos = (x - text_width // 2, y - text_height // 2) # draw.text(text_pos, label_text, fill=label_color, font=font) # return overlay, oligomer_count, fibril_count # @torch.no_grad() # def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, show_overlay, show_stats, max_oligomer_size, keep_only_oligomers): # if image is None: # return "❌ Please upload an image.", None, None, None, "", "", "", "" # image = image.convert("L") # img_np = np.array(image) # img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device) # model = load_model(model_name) # pred = torch.sigmoid(model(img_tensor)).cpu().squeeze().numpy() # if use_otsu: # threshold = filters.threshold_otsu(pred) # binary_mask = (pred > threshold).astype(np.float32) # if remove_noise: # binary_mask = 
morphology.remove_small_objects(binary_mask > 0, 64) # if fill_holes: # binary_mask = morphology.remove_small_holes(binary_mask > 0, 64) # binary_mask = binary_mask.astype(np.float32) # # If user wants to keep only oligomers, remove large fibrils from mask # if keep_only_oligomers: # labeled_mask = measure.label(binary_mask) # regions = measure.regionprops(labeled_mask) # filtered_mask = np.zeros_like(binary_mask) # for region in regions: # if region.area <= max_oligomer_size: # filtered_mask[labeled_mask == region.label] = 1 # binary_mask = filtered_mask.astype(np.float32) # mask_img = Image.fromarray((binary_mask * 255).astype(np.uint8)) # prob_img = Image.fromarray((pred * 255).astype(np.uint8)) # overlay_img = None # oligomer_count = 0 # fibril_count = 0 # if show_overlay: # # Resize mask to original image size # mask_resized = mask_img.resize(image.size, resample=Image.NEAREST).convert("RGB") # # Blend original and mask # base_overlay = Image.blend(image.convert("RGB"), mask_resized, alpha=0.4) # # Draw labels on overlay # overlay_img, oligomer_count, fibril_count = draw_labels_on_image(base_overlay, np.array(mask_img) > 0, max_oligomer_size) # else: # overlay_img = None # stats_text = "" # if show_stats: # labeled_mask = measure.label(binary_mask) # area = np.sum(binary_mask) # mean_conf = np.mean(pred[binary_mask > 0]) if area > 0 else 0 # std_conf = np.std(pred[binary_mask > 0]) if area > 0 else 0 # total_objects = labeled_mask.max() # # Count oligomers and fibrils # oligomers = 0 # fibrils = 0 # for region in measure.regionprops(labeled_mask): # if region.area <= max_oligomer_size: # oligomers += 1 # else: # fibrils += 1 # stats_text = ( # f"🧮 Stats:\n" # f" - Area (px): {area:.0f}\n" # f" - Total Objects: {total_objects}\n" # f" - Oligomers (small): {oligomers}\n" # f" - Fibrils (large): {fibrils}\n" # f" - Mean Confidence: {mean_conf:.3f}\n" # f" - Std Confidence: {std_conf:.3f}" # ) # mask_img.save("predicted_mask.png") # return "✅ Segmentation 
Complete!", mask_img, prob_img, overlay_img, "predicted_mask.png", stats_text, str(oligomer_count), str(fibril_count) # css = """ # body { # background: #f9fafb; # color: #2c3e50; # font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif; # } # h1, h2, h3 { # color: #34495e; # margin-bottom: 0.2em; # } # .gradio-container { # max-width: 1100px; # margin: 1.5rem auto; # padding: 1rem 2rem; # } # .gr-button { # background-color: #0078d7; # color: white; # font-weight: 600; # border-radius: 8px; # padding: 12px 25px; # } # .gr-button:hover { # background-color: #005a9e; # } # .gr-slider label, .gr-checkbox label { # font-weight: 600; # color: #34495e; # } # .gr-image input[type="file"] { # border-radius: 8px; # } # .gr-file label { # font-weight: 600; # } # .gr-textbox textarea { # font-family: monospace; # font-size: 0.9rem; # background: #ecf0f1; # border-radius: 6px; # padding: 8px; # } # """ # with gr.Blocks(css=css) as demo: # gr.Markdown("Upload a grayscale microscopy image and segment fibrillar structures with advanced deep learning models.
") # with gr.Row(): # with gr.Column(scale=1): # input_img = gr.Image(label="Upload Grayscale Image", type="pil", interactive=True, elem_id="input-img", sources=["upload"]) # gr.Examples( # examples=[[f"examples/example{i}.jpg"] for i in range(1, 8)], # inputs=input_img, # label="📁 Try Example Images", # cache_examples=False, # elem_id="examples" # ) # with gr.Column(scale=1): # model_selector = gr.Dropdown( # choices=list(MODEL_OPTIONS.keys()), # value="UNet++ (ResNet34)", # label="Select Model", # interactive=True # ) # model_info = gr.Textbox( # label="Model Description", # interactive=False, # lines=3, # max_lines=5, # elem_id="model-desc", # show_label=True, # container=True # ) # def update_model_info(name): # return MODEL_OPTIONS[name]["description"] # model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info) # gr.Markdown("### Segmentation Options") # threshold_slider = gr.Slider( # minimum=0, maximum=1, value=0.5, step=0.01, # label="Segmentation Threshold", # interactive=True, # info="Adjust threshold for binarizing segmentation probability." # ) # use_otsu = gr.Checkbox( # label="Use Otsu Threshold", # value=False, # info="Automatically select optimal threshold using Otsu's method." # ) # remove_noise = gr.Checkbox( # label="Remove Small Objects", # value=False, # info="Remove small noise blobs from mask." # ) # fill_holes = gr.Checkbox( # label="Fill Holes in Mask", # value=True, # info="Fill small holes inside segmented objects." # ) # show_overlay = gr.Checkbox( # label="Show Overlay on Original", # value=True, # info="Display the mask overlaid on the original image." # ) # show_stats = gr.Checkbox( # label="Show Area & Confidence Stats", # value=True, # info="Display segmentation statistics like area and confidence." 
#             )
#             max_oligomer_size = gr.Slider(
#                 minimum=10, maximum=1000, value=400, step=10,
#                 label="Max Oligomer Size (px)",
#                 interactive=True,
#                 info="Objects smaller or equal to this area (in pixels) are oligomers; larger are fibrils."
#             )
#             keep_only_oligomers = gr.Checkbox(
#                 label="Keep Only Oligomers (Remove Large Fibrils)",
#                 value=False,
#                 info="If enabled, only oligomers remain in the final mask."
#             )
#             submit = gr.Button("🟢 Segment Image", variant="primary", elem_id="submit-btn")
#     gr.Markdown("---")
#     status = gr.Textbox(label="Status", interactive=False, lines=1, elem_id="status-msg")
#     with gr.Row():
#         mask_output = gr.Image(label="Binary Mask", interactive=False, type="pil")
#         prob_output = gr.Image(label="Confidence Map", interactive=False, type="pil")
#         overlay_output = gr.Image(label="Overlay with Labels", interactive=False, type="pil")
#     stats_output = gr.Textbox(label="Segmentation Stats", interactive=False, lines=8, elem_id="stats")
#     # Optional: show counts separately with big text
#     oligomer_count_txt = gr.Textbox(label="Oligomer Count", interactive=False, lines=1)
#     fibril_count_txt = gr.Textbox(label="Fibril Count", interactive=False, lines=1)
#     file_output = gr.File(label="Download Segmentation Mask")
#     submit.click(
#         fn=predict,
#         inputs=[
#             input_img, model_selector, threshold_slider, use_otsu,
#             remove_noise, fill_holes, show_overlay, show_stats,
#             max_oligomer_size, keep_only_oligomers
#         ],
#         outputs=[
#             status, mask_output, prob_output, overlay_output,
#             file_output, stats_output, oligomer_count_txt, fibril_count_txt
#         ]
#     )
#     demo.load(fn=update_model_info, inputs=model_selector, outputs=model_info)
# if __name__ == "__main__":
#     demo.launch()
# +++++++++++++++ Final Version: 1.8.0 ++++++++++++++++++++++
# Last Updated: 10 July 2025
# Improvements: Added model selection, better UI, and device handling, oligomer labeling and fibril labeling
# +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
import os
import tempfile

import torch
import numpy as np
from PIL import Image, ImageDraw, ImageFont
import albumentations as A
from albumentations.pytorch import ToTensorV2
import segmentation_models_pytorch as smp
import gradio as gr
from skimage import filters, measure, morphology

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

MODEL_OPTIONS = {
    "UNet++ (ResNet34)": {
        "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
        "encoder": "resnet34",
        "architecture": "UnetPlusPlus",
        "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy."
    },
    "UNet (ResNet34)": {
        "path": "./model/unet_fibril_seg_model.pth",
        "encoder": "resnet34",
        "architecture": "Unet",
        "description": "Classic UNet with ResNet34 encoder — fast and lightweight."
    },
    "UNet++ (efficientnet-b3)": {
        "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
        "encoder": "efficientnet-b3",
        "architecture": "UnetPlusPlus",
        "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU."
    },
    "DeepLabV3Plus (efficientnet-b3)": {
        "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth",
        "encoder": "efficientnet-b3",
        "architecture": "DeepLabV3Plus",
        "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU."
    }
}

# Side length (px) every input is resized to before inference. The prediction is
# never resized back, so masks, the confidence map and all region statistics
# (areas, lengths) are measured in this IMG_SIZE x IMG_SIZE space.
IMG_SIZE = 512
# Size (px) used both for "Remove Small Objects" and "Fill Holes in Mask".
SMALL_REGION_PX = 64
# Regions whose major axis is at least this long (px) are always fibrils,
# even when their area is below the oligomer size limit.
FIBRIL_LENGTH_THRESH = 100

# Maps MODEL_OPTIONS[...]["architecture"] to the smp constructor.
_ARCHITECTURES = {
    "UnetPlusPlus": smp.UnetPlusPlus,
    "Unet": smp.Unet,
    "DeepLabV3Plus": smp.DeepLabV3Plus,
}
# Architectures that were trained with an explicit decoder_channels setting.
_DECODER_CHANNELS = {
    "UnetPlusPlus": (256, 128, 64, 32, 16),
    "Unet": (256, 128, 64, 32, 16),
}


def get_transform(size):
    """Build the preprocessing pipeline: resize to size x size, scale to [-1, 1], to tensor."""
    return A.Compose([
        A.Resize(size, size),
        A.Normalize(mean=(0.5,), std=(0.5,)),
        ToTensorV2()
    ])


transform = get_transform(IMG_SIZE)
# Loaded models keyed by MODEL_OPTIONS name, so each checkpoint is read only once.
model_cache = {}


def load_model(model_name):
    """Return the (cached) eval-mode model for a MODEL_OPTIONS key.

    Raises:
        KeyError: if model_name is not in MODEL_OPTIONS.
        ValueError: if the configured architecture is not supported.
    """
    if model_name in model_cache:
        return model_cache[model_name]

    config = MODEL_OPTIONS[model_name]
    architecture = config["architecture"]
    if architecture not in _ARCHITECTURES:
        raise ValueError(f"Unsupported architecture: {architecture!r}")

    kwargs = dict(
        encoder_name=config["encoder"],
        # No pretrained download: load_state_dict below (strict) overwrites every
        # weight anyway, so ImageNet weights would only cost time and network access.
        encoder_weights=None,
        in_channels=1,
        classes=1,
        activation=None,
    )
    if architecture in _DECODER_CHANNELS:
        kwargs["decoder_channels"] = _DECODER_CHANNELS[architecture]

    model = _ARCHITECTURES[architecture](**kwargs)
    model.load_state_dict(torch.load(config["path"], map_location=device))
    model.to(device).eval()
    model_cache[model_name] = model
    return model


def _is_oligomer(region, max_oligomer_size, fibril_length_thresh=FIBRIL_LENGTH_THRESH):
    """Classify one skimage region: True for an oligomer, False for a fibril.

    A region is an oligomer when it is small (area <= max_oligomer_size px) and
    not elongated (major axis shorter than fibril_length_thresh px). Everything
    else is a fibril. The mask filter, the statistics and the overlay labels all
    use this single rule, so their counts agree.
    """
    return region.area <= max_oligomer_size and region.major_axis_length < fibril_length_thresh


def _count_objects(binary_mask, max_oligomer_size, fibril_length_thresh=FIBRIL_LENGTH_THRESH):
    """Return (oligomer_count, fibril_count) for the connected regions of binary_mask."""
    oligomers = fibrils = 0
    for region in measure.regionprops(measure.label(np.asarray(binary_mask) > 0)):
        if _is_oligomer(region, max_oligomer_size, fibril_length_thresh):
            oligomers += 1
        else:
            fibrils += 1
    return oligomers, fibrils


def _load_label_font():
    """Return an 18pt Arial font, or PIL's default font if Arial is unavailable."""
    try:
        return ImageFont.truetype("arial.ttf", 18)
    except OSError:
        return ImageFont.load_default()


def draw_labels_on_image(orig_img, binary_mask, max_oligomer_size, fibril_length_thresh=FIBRIL_LENGTH_THRESH):
    """Draw numbered oligomer (green "O#") and fibril (red "F#") markers on a copy of orig_img.

    binary_mask may have a different resolution than orig_img (predict passes
    the model-resolution mask). Regions are classified in mask space with
    _is_oligomer, and their centroids are rescaled to orig_img's size before
    drawing.

    Returns:
        (overlay_image, oligomer_count, fibril_count)
    """
    overlay = orig_img.convert("RGB").copy()
    mask = np.asarray(binary_mask) > 0
    if not mask.any():
        return overlay, 0, 0

    draw = ImageDraw.Draw(overlay)
    font = _load_label_font()
    # Mask pixel -> overlay pixel scale factors (1.0 when the sizes already match).
    scale_x = overlay.width / mask.shape[1]
    scale_y = overlay.height / mask.shape[0]

    oligomer_count = 0
    fibril_count = 0
    for region in measure.regionprops(measure.label(mask)):
        if _is_oligomer(region, max_oligomer_size, fibril_length_thresh):
            oligomer_count += 1
            label_text = f"O{oligomer_count}"
            label_color = (0, 255, 0)  # Green for oligomers
        else:
            fibril_count += 1
            label_text = f"F{fibril_count}"
            label_color = (255, 0, 0)  # Red for fibrils

        row, col = region.centroid
        x, y = int(col * scale_x), int(row * scale_y)

        r = 12
        draw.ellipse((x - r, y - r, x + r, y + r), outline=label_color, width=2)
        bbox = draw.textbbox((0, 0), label_text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
        draw.text((x - text_width // 2, y - text_height // 2), label_text, fill=label_color, font=font)

    return overlay, oligomer_count, fibril_count


def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, show_overlay, show_stats, max_oligomer_size, keep_only_oligomers):
    """Segment an uploaded image and build every output shown in the UI.

    Returns, in the order wired to submit.click:
        status text, binary mask image, confidence-map image, labeled overlay
        (or None), path of the saved mask PNG (or None), stats text,
        oligomer count (str), fibril count (str).
    """
    if image is None:
        # gr.File expects None (not "") when there is nothing to download.
        return "❌ Please upload an image.", None, None, None, None, "", "", ""

    image = image.convert("L")
    img_tensor = transform(image=np.array(image))["image"].unsqueeze(0).to(device)
    model = load_model(model_name)
    with torch.no_grad():
        pred = torch.sigmoid(model(img_tensor)).cpu().squeeze().numpy()

    if use_otsu:
        threshold = filters.threshold_otsu(pred)
    binary_mask = pred > threshold
    if remove_noise:
        binary_mask = morphology.remove_small_objects(binary_mask, SMALL_REGION_PX)
    if fill_holes:
        binary_mask = morphology.remove_small_holes(binary_mask, SMALL_REGION_PX)

    if keep_only_oligomers:
        # Keep only regions classified as oligomers (single vectorized pass).
        labeled = measure.label(binary_mask)
        keep = [r.label for r in measure.regionprops(labeled) if _is_oligomer(r, max_oligomer_size)]
        binary_mask = np.isin(labeled, keep)

    mask_img = Image.fromarray((binary_mask * 255).astype(np.uint8))
    prob_img = Image.fromarray((pred * 255).astype(np.uint8))

    # Counted independently of the overlay so the count boxes are always correct.
    oligomer_count, fibril_count = _count_objects(binary_mask, max_oligomer_size)

    overlay_img = None
    if show_overlay:
        # Resize mask to the original image size and blend it in.
        mask_resized = mask_img.resize(image.size, resample=Image.NEAREST).convert("RGB")
        base_overlay = Image.blend(image.convert("RGB"), mask_resized, alpha=0.4)
        overlay_img, _, _ = draw_labels_on_image(
            base_overlay, binary_mask, max_oligomer_size, fibril_length_thresh=FIBRIL_LENGTH_THRESH
        )

    stats_text = ""
    if show_stats:
        area = float(binary_mask.sum())
        stats_text = (
            f"🧮 Stats:\n"
            f" - Area (px): {area:.0f}\n"
            f" - Total Objects: {oligomer_count + fibril_count}\n"
            f" - Oligomers (small): {oligomer_count}\n"
            f" - Fibrils (large): {fibril_count}\n"
        )

    # A fresh directory per request, so concurrent sessions never overwrite each
    # other's download while the file keeps its "predicted_mask.png" name.
    mask_path = os.path.join(tempfile.mkdtemp(prefix="fibril_seg_"), "predicted_mask.png")
    mask_img.save(mask_path)
    return "✅ Segmentation Complete!", mask_img, prob_img, overlay_img, mask_path, stats_text, str(oligomer_count), str(fibril_count)


css = """
body {
    background: #f9fafb;
    color: #2c3e50;
    font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
}
h1, h2, h3 {
    color: #34495e;
    margin-bottom: 0.2em;
}
.gradio-container {
    max-width: 1100px;
    margin: 1.5rem auto;
    padding: 1rem 2rem;
}
.gr-button {
    background-color: #0078d7;
    color: white;
    font-weight: 600;
    border-radius: 8px;
    padding: 12px 25px;
}
.gr-button:hover {
    background-color: #005a9e;
}
.gr-slider label, .gr-checkbox label {
    font-weight: 600;
    color: #34495e;
}
.gr-image input[type="file"] {
    border-radius: 8px;
}
.gr-file label {
    font-weight: 600;
}
.gr-textbox textarea {
    font-family: monospace;
    font-size: 0.9rem;
    background: #ecf0f1;
    border-radius: 6px;
    padding: 8px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("Upload a grayscale microscopy image and segment fibrillar structures with advanced deep learning models.")
    with gr.Row():
        with gr.Column(scale=1):
            input_img = gr.Image(label="Upload Grayscale Image", type="pil", interactive=True, elem_id="input-img", sources=["upload"])
            gr.Examples(
                examples=[[f"examples/example{i}.jpg"] for i in range(1, 8)],
                inputs=input_img,
                label="📁 Try Example Images",
                cache_examples=False,
                elem_id="examples"
            )
        with gr.Column(scale=1):
            model_selector = gr.Dropdown(
                choices=list(MODEL_OPTIONS.keys()),
                value="UNet++ (ResNet34)",
                label="Select Model",
                interactive=True
            )
            model_info = gr.Textbox(
                label="Model Description",
                interactive=False,
                lines=3,
                max_lines=5,
                elem_id="model-desc",
                show_label=True,
                container=True
            )

            def update_model_info(name):
                """Return the description text of the selected model."""
                return MODEL_OPTIONS[name]["description"]

            model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info)

            gr.Markdown("### Segmentation Options")
            threshold_slider = gr.Slider(
                minimum=0, maximum=1, value=0.5, step=0.01,
                label="Segmentation Threshold",
                interactive=True,
                info="Adjust threshold for binarizing segmentation probability."
            )
            use_otsu = gr.Checkbox(
                label="Use Otsu Threshold",
                value=False,
                info="Automatically select optimal threshold using Otsu's method."
            )
            remove_noise = gr.Checkbox(
                label="Remove Small Objects",
                value=False,
                info="Remove small noise blobs from mask."
            )
            fill_holes = gr.Checkbox(
                label="Fill Holes in Mask",
                value=True,
                info="Fill small holes inside segmented objects."
            )
            show_overlay = gr.Checkbox(
                label="Show Overlay on Original",
                value=True,
                info="Display the mask overlaid on the original image."
            )
            show_stats = gr.Checkbox(
                label="Show Area & Confidence Stats",
                value=True,
                info="Display segmentation statistics like area and confidence."
            )
            max_oligomer_size = gr.Slider(
                minimum=10, maximum=1000, value=400, step=10,
                label="Max Oligomer Size (px)",
                interactive=True,
                info=(
                    "Objects with area at or below this value (in pixels) are oligomers unless they are "
                    f"elongated (major axis ≥ {FIBRIL_LENGTH_THRESH} px); all other objects are fibrils."
                )
            )
            keep_only_oligomers = gr.Checkbox(
                label="Keep Only Oligomers (Remove Large Fibrils)",
                value=False,
                info="If enabled, only oligomers remain in the final mask."
            )
            submit = gr.Button("🟢 Segment Image", variant="primary", elem_id="submit-btn")

    gr.Markdown("---")
    status = gr.Textbox(label="Status", interactive=False, lines=1, elem_id="status-msg")
    with gr.Row():
        mask_output = gr.Image(label="Binary Mask", interactive=False, type="pil")
        prob_output = gr.Image(label="Confidence Map", interactive=False, type="pil")
        overlay_output = gr.Image(label="Overlay with Labels", interactive=False, type="pil")
    stats_output = gr.Textbox(label="Segmentation Stats", interactive=False, lines=8, elem_id="stats")
    # Counts shown separately for quick reading.
    oligomer_count_txt = gr.Textbox(label="Oligomer Count", interactive=False, lines=1)
    fibril_count_txt = gr.Textbox(label="Fibril Count", interactive=False, lines=1)
    file_output = gr.File(label="Download Segmentation Mask")

    submit.click(
        fn=predict,
        inputs=[
            input_img, model_selector, threshold_slider, use_otsu,
            remove_noise, fill_holes, show_overlay, show_stats,
            max_oligomer_size, keep_only_oligomers
        ],
        outputs=[
            status, mask_output, prob_output, overlay_output,
            file_output, stats_output, oligomer_count_txt, fibril_count_txt
        ]
    )
    # Populate the description box for the default model on page load.
    demo.load(fn=update_model_info, inputs=model_selector, outputs=model_info)

if __name__ == "__main__":
    demo.launch()