himanshuch8055 commited on
Commit
12bc53e
·
1 Parent(s): 41d5edc

Add model selection, enhanced UI, and oligomer/fibril labeling; update segmentation functionality

Browse files
Files changed (2) hide show
  1. app.py +776 -14
  2. predicted_mask.png +0 -0
app.py CHANGED
@@ -1000,10 +1000,659 @@
1000
  # Last Updated: 10 July 2025
1001
  # Improvements: Added model selection, better UI, and device handling
1002
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1003
  import os
1004
  import torch
1005
  import numpy as np
1006
- from PIL import Image
1007
  import albumentations as A
1008
  from albumentations.pytorch import ToTensorV2
1009
  import segmentation_models_pytorch as smp
@@ -1074,10 +1723,72 @@ def load_model(model_name):
1074
  model_cache[model_name] = model.to(device)
1075
  return model_cache[model_name]
1076
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1077
  @torch.no_grad()
1078
- def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, show_overlay, show_stats):
1079
  if image is None:
1080
- return "❌ Please upload an image.", None, None, None, "", ""
1081
 
1082
  image = image.convert("L")
1083
  img_np = np.array(image)
@@ -1095,14 +1806,35 @@ def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, sh
1095
  binary_mask = morphology.remove_small_objects(binary_mask > 0, 64)
1096
  if fill_holes:
1097
  binary_mask = morphology.remove_small_holes(binary_mask > 0, 64)
 
1098
  binary_mask = binary_mask.astype(np.float32)
1099
 
 
 
 
 
 
 
 
 
 
 
1100
  mask_img = Image.fromarray((binary_mask * 255).astype(np.uint8))
1101
  prob_img = Image.fromarray((pred * 255).astype(np.uint8))
1102
 
 
 
 
 
1103
  if show_overlay:
 
1104
  mask_resized = mask_img.resize(image.size, resample=Image.NEAREST).convert("RGB")
1105
- overlay_img = Image.blend(image.convert("RGB"), mask_resized, alpha=0.4)
 
 
 
 
 
1106
  else:
1107
  overlay_img = None
1108
 
@@ -1112,17 +1844,30 @@ def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, sh
1112
  area = np.sum(binary_mask)
1113
  mean_conf = np.mean(pred[binary_mask > 0]) if area > 0 else 0
1114
  std_conf = np.std(pred[binary_mask > 0]) if area > 0 else 0
 
 
 
 
 
 
 
 
 
 
 
1115
  stats_text = (
1116
- f"🧮 Stats:\n - Area (px): {area:.0f}\n"
1117
- f" - Objects: {labeled_mask.max()}\n"
1118
- f" - Mean Conf: {mean_conf:.3f}\n"
1119
- f" - Std Conf: {std_conf:.3f}"
 
 
 
1120
  )
1121
 
1122
  mask_img.save("predicted_mask.png")
1123
 
1124
- return "✅ Segmentation Complete!", mask_img, prob_img, overlay_img, "predicted_mask.png", stats_text
1125
-
1126
 
1127
  css = """
1128
  body {
@@ -1236,6 +1981,18 @@ with gr.Blocks(css=css) as demo:
1236
  info="Display segmentation statistics like area and confidence."
1237
  )
1238
 
 
 
 
 
 
 
 
 
 
 
 
 
1239
  submit = gr.Button("🟢 Segment Image", variant="primary", elem_id="submit-btn")
1240
 
1241
  gr.Markdown("---")
@@ -1245,9 +2002,13 @@ with gr.Blocks(css=css) as demo:
1245
  with gr.Row():
1246
  mask_output = gr.Image(label="Binary Mask", interactive=False, type="pil")
1247
  prob_output = gr.Image(label="Confidence Map", interactive=False, type="pil")
1248
- overlay_output = gr.Image(label="Overlay", interactive=False, type="pil")
 
 
1249
 
1250
- stats_output = gr.Textbox(label="Segmentation Stats", interactive=False, lines=6, elem_id="stats")
 
 
1251
 
1252
  file_output = gr.File(label="Download Segmentation Mask")
1253
 
@@ -1255,11 +2016,12 @@ with gr.Blocks(css=css) as demo:
1255
  fn=predict,
1256
  inputs=[
1257
  input_img, model_selector, threshold_slider, use_otsu,
1258
- remove_noise, fill_holes, show_overlay, show_stats
 
1259
  ],
1260
  outputs=[
1261
  status, mask_output, prob_output, overlay_output,
1262
- file_output, stats_output
1263
  ]
1264
  )
1265
 
 
1000
  # Last Updated: 10 July 2025
1001
  # Improvements: Added model selection, better UI, and device handling
1002
 
1003
+ # import os
1004
+ # import torch
1005
+ # import numpy as np
1006
+ # from PIL import Image
1007
+ # import albumentations as A
1008
+ # from albumentations.pytorch import ToTensorV2
1009
+ # import segmentation_models_pytorch as smp
1010
+ # import gradio as gr
1011
+ # from skimage import filters, measure, morphology
1012
+
1013
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1014
+
1015
+ # MODEL_OPTIONS = {
1016
+ # "UNet++ (ResNet34)": {
1017
+ # "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
1018
+ # "encoder": "resnet34",
1019
+ # "architecture": "UnetPlusPlus",
1020
+ # "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy."
1021
+ # },
1022
+ # "UNet (ResNet34)": {
1023
+ # "path": "./model/unet_fibril_seg_model.pth",
1024
+ # "encoder": "resnet34",
1025
+ # "architecture": "Unet",
1026
+ # "description": "Classic UNet with ResNet34 encoder — fast and lightweight."
1027
+ # },
1028
+ # "UNet++ (efficientnet-b3)": {
1029
+ # "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
1030
+ # "encoder": "efficientnet-b3",
1031
+ # "architecture": "UnetPlusPlus",
1032
+ # "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU."
1033
+ # },
1034
+ # "DeepLabV3Plus (efficientnet-b3)": {
1035
+ # "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth",
1036
+ # "encoder": "efficientnet-b3",
1037
+ # "architecture": "DeepLabV3Plus",
1038
+ # "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU."
1039
+ # }
1040
+ # }
1041
+
1042
+ # def get_transform(size):
1043
+ # return A.Compose([
1044
+ # A.Resize(size, size),
1045
+ # A.Normalize(mean=(0.5,), std=(0.5,)),
1046
+ # ToTensorV2()
1047
+ # ])
1048
+
1049
+ # transform = get_transform(512)
1050
+ # model_cache = {}
1051
+
1052
+ # def load_model(model_name):
1053
+ # if model_name in model_cache:
1054
+ # return model_cache[model_name]
1055
+ # config = MODEL_OPTIONS[model_name]
1056
+ # if config["architecture"] == "UnetPlusPlus":
1057
+ # model = smp.UnetPlusPlus(
1058
+ # encoder_name=config["encoder"], encoder_weights="imagenet",
1059
+ # decoder_channels=(256, 128, 64, 32, 16),
1060
+ # in_channels=1, classes=1, activation=None)
1061
+ # elif config["architecture"] == "Unet":
1062
+ # model = smp.Unet(
1063
+ # encoder_name=config["encoder"], encoder_weights="imagenet",
1064
+ # decoder_channels=(256, 128, 64, 32, 16),
1065
+ # in_channels=1, classes=1, activation=None)
1066
+ # elif config["architecture"] == "DeepLabV3Plus":
1067
+ # model = smp.DeepLabV3Plus(
1068
+ # encoder_name=config["encoder"], encoder_weights="imagenet",
1069
+ # in_channels=1, classes=1, activation=None)
1070
+ # else:
1071
+ # raise ValueError("Unsupported architecture.")
1072
+ # model.load_state_dict(torch.load(config["path"], map_location=device))
1073
+ # model.eval()
1074
+ # model_cache[model_name] = model.to(device)
1075
+ # return model_cache[model_name]
1076
+
1077
+ # @torch.no_grad()
1078
+ # def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, show_overlay, show_stats):
1079
+ # if image is None:
1080
+ # return "❌ Please upload an image.", None, None, None, "", ""
1081
+
1082
+ # image = image.convert("L")
1083
+ # img_np = np.array(image)
1084
+ # img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
1085
+
1086
+ # model = load_model(model_name)
1087
+ # pred = torch.sigmoid(model(img_tensor)).cpu().squeeze().numpy()
1088
+
1089
+ # if use_otsu:
1090
+ # threshold = filters.threshold_otsu(pred)
1091
+
1092
+ # binary_mask = (pred > threshold).astype(np.float32)
1093
+
1094
+ # if remove_noise:
1095
+ # binary_mask = morphology.remove_small_objects(binary_mask > 0, 64)
1096
+ # if fill_holes:
1097
+ # binary_mask = morphology.remove_small_holes(binary_mask > 0, 64)
1098
+ # binary_mask = binary_mask.astype(np.float32)
1099
+
1100
+ # mask_img = Image.fromarray((binary_mask * 255).astype(np.uint8))
1101
+ # prob_img = Image.fromarray((pred * 255).astype(np.uint8))
1102
+
1103
+ # if show_overlay:
1104
+ # mask_resized = mask_img.resize(image.size, resample=Image.NEAREST).convert("RGB")
1105
+ # overlay_img = Image.blend(image.convert("RGB"), mask_resized, alpha=0.4)
1106
+ # else:
1107
+ # overlay_img = None
1108
+
1109
+ # stats_text = ""
1110
+ # if show_stats:
1111
+ # labeled_mask = measure.label(binary_mask)
1112
+ # area = np.sum(binary_mask)
1113
+ # mean_conf = np.mean(pred[binary_mask > 0]) if area > 0 else 0
1114
+ # std_conf = np.std(pred[binary_mask > 0]) if area > 0 else 0
1115
+ # stats_text = (
1116
+ # f"🧮 Stats:\n - Area (px): {area:.0f}\n"
1117
+ # f" - Objects: {labeled_mask.max()}\n"
1118
+ # f" - Mean Conf: {mean_conf:.3f}\n"
1119
+ # f" - Std Conf: {std_conf:.3f}"
1120
+ # )
1121
+
1122
+ # mask_img.save("predicted_mask.png")
1123
+
1124
+ # return "✅ Segmentation Complete!", mask_img, prob_img, overlay_img, "predicted_mask.png", stats_text
1125
+
1126
+
1127
+ # css = """
1128
+ # body {
1129
+ # background: #f9fafb;
1130
+ # color: #2c3e50;
1131
+ # font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
1132
+ # }
1133
+ # h1, h2, h3 {
1134
+ # color: #34495e;
1135
+ # margin-bottom: 0.2em;
1136
+ # }
1137
+ # .gradio-container {
1138
+ # max-width: 1100px;
1139
+ # margin: 1.5rem auto;
1140
+ # padding: 1rem 2rem;
1141
+ # }
1142
+ # .gr-button {
1143
+ # background-color: #0078d7;
1144
+ # color: white;
1145
+ # font-weight: 600;
1146
+ # border-radius: 8px;
1147
+ # padding: 12px 25px;
1148
+ # }
1149
+ # .gr-button:hover {
1150
+ # background-color: #005a9e;
1151
+ # }
1152
+ # .gr-slider label, .gr-checkbox label {
1153
+ # font-weight: 600;
1154
+ # color: #34495e;
1155
+ # }
1156
+ # .gr-image input[type="file"] {
1157
+ # border-radius: 8px;
1158
+ # }
1159
+ # .gr-file label {
1160
+ # font-weight: 600;
1161
+ # }
1162
+ # .gr-textbox textarea {
1163
+ # font-family: monospace;
1164
+ # font-size: 0.9rem;
1165
+ # background: #ecf0f1;
1166
+ # border-radius: 6px;
1167
+ # padding: 8px;
1168
+ # }
1169
+ # """
1170
+
1171
+ # with gr.Blocks(css=css) as demo:
1172
+ # gr.Markdown("<h1 style='text-align:center; margin-bottom:0.25em;'>🧬 Fibril Segmentation Interface</h1>")
1173
+ # gr.Markdown("<p style='text-align:center; font-size:1.1rem; color:#555; margin-top:0; margin-bottom:2em;'>Upload a grayscale microscopy image and segment fibrillar structures with advanced deep learning models.</p>")
1174
+
1175
+ # with gr.Row():
1176
+ # with gr.Column(scale=1):
1177
+ # input_img = gr.Image(label="Upload Grayscale Image", type="pil", interactive=True, elem_id="input-img", sources=["upload"])
1178
+ # gr.Examples(
1179
+ # examples=[[f"examples/example{i}.jpg"] for i in range(1, 8)],
1180
+ # inputs=input_img,
1181
+ # label="📁 Try Example Images",
1182
+ # cache_examples=False,
1183
+ # elem_id="examples"
1184
+ # )
1185
+ # with gr.Column(scale=1):
1186
+ # model_selector = gr.Dropdown(
1187
+ # choices=list(MODEL_OPTIONS.keys()),
1188
+ # value="UNet++ (ResNet34)",
1189
+ # label="Select Model",
1190
+ # interactive=True
1191
+ # )
1192
+ # model_info = gr.Textbox(
1193
+ # label="Model Description",
1194
+ # interactive=False,
1195
+ # lines=3,
1196
+ # max_lines=5,
1197
+ # elem_id="model-desc",
1198
+ # show_label=True,
1199
+ # container=True
1200
+ # )
1201
+
1202
+ # def update_model_info(name):
1203
+ # return MODEL_OPTIONS[name]["description"]
1204
+ # model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info)
1205
+
1206
+ # gr.Markdown("### Segmentation Options")
1207
+ # threshold_slider = gr.Slider(
1208
+ # minimum=0, maximum=1, value=0.5, step=0.01,
1209
+ # label="Segmentation Threshold",
1210
+ # interactive=True,
1211
+ # info="Adjust threshold for binarizing segmentation probability."
1212
+ # )
1213
+ # use_otsu = gr.Checkbox(
1214
+ # label="Use Otsu Threshold",
1215
+ # value=False,
1216
+ # info="Automatically select optimal threshold using Otsu's method."
1217
+ # )
1218
+ # remove_noise = gr.Checkbox(
1219
+ # label="Remove Small Objects",
1220
+ # value=False,
1221
+ # info="Remove small noise blobs from mask."
1222
+ # )
1223
+ # fill_holes = gr.Checkbox(
1224
+ # label="Fill Holes in Mask",
1225
+ # value=True,
1226
+ # info="Fill small holes inside segmented objects."
1227
+ # )
1228
+ # show_overlay = gr.Checkbox(
1229
+ # label="Show Overlay on Original",
1230
+ # value=True,
1231
+ # info="Display the mask overlaid on the original image."
1232
+ # )
1233
+ # show_stats = gr.Checkbox(
1234
+ # label="Show Area & Confidence Stats",
1235
+ # value=True,
1236
+ # info="Display segmentation statistics like area and confidence."
1237
+ # )
1238
+
1239
+ # submit = gr.Button("🟢 Segment Image", variant="primary", elem_id="submit-btn")
1240
+
1241
+ # gr.Markdown("---")
1242
+
1243
+ # status = gr.Textbox(label="Status", interactive=False, lines=1, elem_id="status-msg")
1244
+
1245
+ # with gr.Row():
1246
+ # mask_output = gr.Image(label="Binary Mask", interactive=False, type="pil")
1247
+ # prob_output = gr.Image(label="Confidence Map", interactive=False, type="pil")
1248
+ # overlay_output = gr.Image(label="Overlay", interactive=False, type="pil")
1249
+
1250
+ # stats_output = gr.Textbox(label="Segmentation Stats", interactive=False, lines=6, elem_id="stats")
1251
+
1252
+ # file_output = gr.File(label="Download Segmentation Mask")
1253
+
1254
+ # submit.click(
1255
+ # fn=predict,
1256
+ # inputs=[
1257
+ # input_img, model_selector, threshold_slider, use_otsu,
1258
+ # remove_noise, fill_holes, show_overlay, show_stats
1259
+ # ],
1260
+ # outputs=[
1261
+ # status, mask_output, prob_output, overlay_output,
1262
+ # file_output, stats_output
1263
+ # ]
1264
+ # )
1265
+
1266
+ # demo.load(fn=update_model_info, inputs=model_selector, outputs=model_info)
1267
+
1268
+ # if __name__ == "__main__":
1269
+ # demo.launch()
1270
+
1271
+
1272
+
1273
+ # +++++++++++++++ Final Version: 1.7.0 ++++++++++++++++++++++
1274
+ # Last Updated: 10 July 2025
1275
+ # Improvements: Added model selection, better UI, and device handling, oligomer labeling and fibril labeling
1276
+ # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1277
+
1278
+ # import os
1279
+ # import torch
1280
+ # import numpy as np
1281
+ # from PIL import Image, ImageDraw, ImageFont
1282
+ # import albumentations as A
1283
+ # from albumentations.pytorch import ToTensorV2
1284
+ # import segmentation_models_pytorch as smp
1285
+ # import gradio as gr
1286
+ # from skimage import filters, measure, morphology
1287
+
1288
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
1289
+
1290
+ # MODEL_OPTIONS = {
1291
+ # "UNet++ (ResNet34)": {
1292
+ # "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
1293
+ # "encoder": "resnet34",
1294
+ # "architecture": "UnetPlusPlus",
1295
+ # "description": "UNet++ with ResNet34 encoder — good balance of speed and accuracy."
1296
+ # },
1297
+ # "UNet (ResNet34)": {
1298
+ # "path": "./model/unet_fibril_seg_model.pth",
1299
+ # "encoder": "resnet34",
1300
+ # "architecture": "Unet",
1301
+ # "description": "Classic UNet with ResNet34 encoder — fast and lightweight."
1302
+ # },
1303
+ # "UNet++ (efficientnet-b3)": {
1304
+ # "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
1305
+ # "encoder": "efficientnet-b3",
1306
+ # "architecture": "UnetPlusPlus",
1307
+ # "description": "UNet++ with EfficientNet-B3 — more accurate, slower on CPU."
1308
+ # },
1309
+ # "DeepLabV3Plus (efficientnet-b3)": {
1310
+ # "path": "./model/encoder_efficientnet-b3_decoder_DeepLabV3Plus_fibril_seg_model.pth",
1311
+ # "encoder": "efficientnet-b3",
1312
+ # "architecture": "DeepLabV3Plus",
1313
+ # "description": "DeepLabV3+ with EfficientNet-B3 — high performance, slower on CPU."
1314
+ # }
1315
+ # }
1316
+
1317
+ # def get_transform(size):
1318
+ # return A.Compose([
1319
+ # A.Resize(size, size),
1320
+ # A.Normalize(mean=(0.5,), std=(0.5,)),
1321
+ # ToTensorV2()
1322
+ # ])
1323
+
1324
+ # transform = get_transform(512)
1325
+ # model_cache = {}
1326
+
1327
+ # def load_model(model_name):
1328
+ # if model_name in model_cache:
1329
+ # return model_cache[model_name]
1330
+ # config = MODEL_OPTIONS[model_name]
1331
+ # if config["architecture"] == "UnetPlusPlus":
1332
+ # model = smp.UnetPlusPlus(
1333
+ # encoder_name=config["encoder"], encoder_weights="imagenet",
1334
+ # decoder_channels=(256, 128, 64, 32, 16),
1335
+ # in_channels=1, classes=1, activation=None)
1336
+ # elif config["architecture"] == "Unet":
1337
+ # model = smp.Unet(
1338
+ # encoder_name=config["encoder"], encoder_weights="imagenet",
1339
+ # decoder_channels=(256, 128, 64, 32, 16),
1340
+ # in_channels=1, classes=1, activation=None)
1341
+ # elif config["architecture"] == "DeepLabV3Plus":
1342
+ # model = smp.DeepLabV3Plus(
1343
+ # encoder_name=config["encoder"], encoder_weights="imagenet",
1344
+ # in_channels=1, classes=1, activation=None)
1345
+ # else:
1346
+ # raise ValueError("Unsupported architecture.")
1347
+ # model.load_state_dict(torch.load(config["path"], map_location=device))
1348
+ # model.eval()
1349
+ # model_cache[model_name] = model.to(device)
1350
+ # return model_cache[model_name]
1351
+
1352
+ # def draw_labels_on_image(orig_img, binary_mask, max_oligomer_size):
1353
+ # """
1354
+ # Draws numbers on the overlay image labeling oligomers and fibrils.
1355
+
1356
+ # Oligomers labeled as O1, O2, ...
1357
+ # Fibrils labeled as F1, F2, ...
1358
+ # """
1359
+ # overlay = orig_img.convert("RGB").copy()
1360
+ # draw = ImageDraw.Draw(overlay)
1361
+
1362
+ # # Try to get a nice font; fallback to default if not available
1363
+ # try:
1364
+ # font = ImageFont.truetype("arial.ttf", 18)
1365
+ # except IOError:
1366
+ # font = ImageFont.load_default()
1367
+
1368
+ # labeled_mask = measure.label(binary_mask)
1369
+ # regions = measure.regionprops(labeled_mask)
1370
+
1371
+ # oligomer_count = 0
1372
+ # fibril_count = 0
1373
+
1374
+ # for region in regions:
1375
+ # area = region.area
1376
+ # centroid = region.centroid # (row, col)
1377
+ # x, y = int(centroid[1]), int(centroid[0])
1378
+
1379
+ # if area <= max_oligomer_size:
1380
+ # oligomer_count += 1
1381
+ # label_text = f"O{oligomer_count}"
1382
+ # label_color = (0, 255, 0) # Green for oligomers
1383
+ # else:
1384
+ # fibril_count += 1
1385
+ # label_text = f"F{fibril_count}"
1386
+ # label_color = (255, 0, 0) # Red for fibrils
1387
+
1388
+ # # Draw circle around centroid for visibility
1389
+ # r = 12
1390
+ # draw.ellipse((x-r, y-r, x+r, y+r), outline=label_color, width=2)
1391
+ # # Draw label text
1392
+ # bbox = draw.textbbox((0, 0), label_text, font=font)
1393
+ # text_width = bbox[2] - bbox[0]
1394
+ # text_height = bbox[3] - bbox[1]
1395
+ # text_pos = (x - text_width // 2, y - text_height // 2)
1396
+ # draw.text(text_pos, label_text, fill=label_color, font=font)
1397
+
1398
+ # return overlay, oligomer_count, fibril_count
1399
+
1400
+ # @torch.no_grad()
1401
+ # def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, show_overlay, show_stats, max_oligomer_size, keep_only_oligomers):
1402
+ # if image is None:
1403
+ # return "❌ Please upload an image.", None, None, None, "", "", "", ""
1404
+
1405
+ # image = image.convert("L")
1406
+ # img_np = np.array(image)
1407
+ # img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
1408
+
1409
+ # model = load_model(model_name)
1410
+ # pred = torch.sigmoid(model(img_tensor)).cpu().squeeze().numpy()
1411
+
1412
+ # if use_otsu:
1413
+ # threshold = filters.threshold_otsu(pred)
1414
+
1415
+ # binary_mask = (pred > threshold).astype(np.float32)
1416
+
1417
+ # if remove_noise:
1418
+ # binary_mask = morphology.remove_small_objects(binary_mask > 0, 64)
1419
+ # if fill_holes:
1420
+ # binary_mask = morphology.remove_small_holes(binary_mask > 0, 64)
1421
+
1422
+ # binary_mask = binary_mask.astype(np.float32)
1423
+
1424
+ # # If user wants to keep only oligomers, remove large fibrils from mask
1425
+ # if keep_only_oligomers:
1426
+ # labeled_mask = measure.label(binary_mask)
1427
+ # regions = measure.regionprops(labeled_mask)
1428
+ # filtered_mask = np.zeros_like(binary_mask)
1429
+ # for region in regions:
1430
+ # if region.area <= max_oligomer_size:
1431
+ # filtered_mask[labeled_mask == region.label] = 1
1432
+ # binary_mask = filtered_mask.astype(np.float32)
1433
+
1434
+ # mask_img = Image.fromarray((binary_mask * 255).astype(np.uint8))
1435
+ # prob_img = Image.fromarray((pred * 255).astype(np.uint8))
1436
+
1437
+ # overlay_img = None
1438
+ # oligomer_count = 0
1439
+ # fibril_count = 0
1440
+
1441
+ # if show_overlay:
1442
+ # # Resize mask to original image size
1443
+ # mask_resized = mask_img.resize(image.size, resample=Image.NEAREST).convert("RGB")
1444
+ # # Blend original and mask
1445
+ # base_overlay = Image.blend(image.convert("RGB"), mask_resized, alpha=0.4)
1446
+
1447
+ # # Draw labels on overlay
1448
+ # overlay_img, oligomer_count, fibril_count = draw_labels_on_image(base_overlay, np.array(mask_img) > 0, max_oligomer_size)
1449
+ # else:
1450
+ # overlay_img = None
1451
+
1452
+ # stats_text = ""
1453
+ # if show_stats:
1454
+ # labeled_mask = measure.label(binary_mask)
1455
+ # area = np.sum(binary_mask)
1456
+ # mean_conf = np.mean(pred[binary_mask > 0]) if area > 0 else 0
1457
+ # std_conf = np.std(pred[binary_mask > 0]) if area > 0 else 0
1458
+ # total_objects = labeled_mask.max()
1459
+
1460
+ # # Count oligomers and fibrils
1461
+ # oligomers = 0
1462
+ # fibrils = 0
1463
+ # for region in measure.regionprops(labeled_mask):
1464
+ # if region.area <= max_oligomer_size:
1465
+ # oligomers += 1
1466
+ # else:
1467
+ # fibrils += 1
1468
+
1469
+ # stats_text = (
1470
+ # f"🧮 Stats:\n"
1471
+ # f" - Area (px): {area:.0f}\n"
1472
+ # f" - Total Objects: {total_objects}\n"
1473
+ # f" - Oligomers (small): {oligomers}\n"
1474
+ # f" - Fibrils (large): {fibrils}\n"
1475
+ # f" - Mean Confidence: {mean_conf:.3f}\n"
1476
+ # f" - Std Confidence: {std_conf:.3f}"
1477
+ # )
1478
+
1479
+ # mask_img.save("predicted_mask.png")
1480
+
1481
+ # return "✅ Segmentation Complete!", mask_img, prob_img, overlay_img, "predicted_mask.png", stats_text, str(oligomer_count), str(fibril_count)
1482
+
1483
+ # css = """
1484
+ # body {
1485
+ # background: #f9fafb;
1486
+ # color: #2c3e50;
1487
+ # font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
1488
+ # }
1489
+ # h1, h2, h3 {
1490
+ # color: #34495e;
1491
+ # margin-bottom: 0.2em;
1492
+ # }
1493
+ # .gradio-container {
1494
+ # max-width: 1100px;
1495
+ # margin: 1.5rem auto;
1496
+ # padding: 1rem 2rem;
1497
+ # }
1498
+ # .gr-button {
1499
+ # background-color: #0078d7;
1500
+ # color: white;
1501
+ # font-weight: 600;
1502
+ # border-radius: 8px;
1503
+ # padding: 12px 25px;
1504
+ # }
1505
+ # .gr-button:hover {
1506
+ # background-color: #005a9e;
1507
+ # }
1508
+ # .gr-slider label, .gr-checkbox label {
1509
+ # font-weight: 600;
1510
+ # color: #34495e;
1511
+ # }
1512
+ # .gr-image input[type="file"] {
1513
+ # border-radius: 8px;
1514
+ # }
1515
+ # .gr-file label {
1516
+ # font-weight: 600;
1517
+ # }
1518
+ # .gr-textbox textarea {
1519
+ # font-family: monospace;
1520
+ # font-size: 0.9rem;
1521
+ # background: #ecf0f1;
1522
+ # border-radius: 6px;
1523
+ # padding: 8px;
1524
+ # }
1525
+ # """
1526
+
1527
+ # with gr.Blocks(css=css) as demo:
1528
+ # gr.Markdown("<h1 style='text-align:center; margin-bottom:0.25em;'>🧬 Fibril Segmentation Interface</h1>")
1529
+ # gr.Markdown("<p style='text-align:center; font-size:1.1rem; color:#555; margin-top:0; margin-bottom:2em;'>Upload a grayscale microscopy image and segment fibrillar structures with advanced deep learning models.</p>")
1530
+
1531
+ # with gr.Row():
1532
+ # with gr.Column(scale=1):
1533
+ # input_img = gr.Image(label="Upload Grayscale Image", type="pil", interactive=True, elem_id="input-img", sources=["upload"])
1534
+ # gr.Examples(
1535
+ # examples=[[f"examples/example{i}.jpg"] for i in range(1, 8)],
1536
+ # inputs=input_img,
1537
+ # label="📁 Try Example Images",
1538
+ # cache_examples=False,
1539
+ # elem_id="examples"
1540
+ # )
1541
+ # with gr.Column(scale=1):
1542
+ # model_selector = gr.Dropdown(
1543
+ # choices=list(MODEL_OPTIONS.keys()),
1544
+ # value="UNet++ (ResNet34)",
1545
+ # label="Select Model",
1546
+ # interactive=True
1547
+ # )
1548
+ # model_info = gr.Textbox(
1549
+ # label="Model Description",
1550
+ # interactive=False,
1551
+ # lines=3,
1552
+ # max_lines=5,
1553
+ # elem_id="model-desc",
1554
+ # show_label=True,
1555
+ # container=True
1556
+ # )
1557
+
1558
+ # def update_model_info(name):
1559
+ # return MODEL_OPTIONS[name]["description"]
1560
+ # model_selector.change(fn=update_model_info, inputs=model_selector, outputs=model_info)
1561
+
1562
+ # gr.Markdown("### Segmentation Options")
1563
+ # threshold_slider = gr.Slider(
1564
+ # minimum=0, maximum=1, value=0.5, step=0.01,
1565
+ # label="Segmentation Threshold",
1566
+ # interactive=True,
1567
+ # info="Adjust threshold for binarizing segmentation probability."
1568
+ # )
1569
+ # use_otsu = gr.Checkbox(
1570
+ # label="Use Otsu Threshold",
1571
+ # value=False,
1572
+ # info="Automatically select optimal threshold using Otsu's method."
1573
+ # )
1574
+ # remove_noise = gr.Checkbox(
1575
+ # label="Remove Small Objects",
1576
+ # value=False,
1577
+ # info="Remove small noise blobs from mask."
1578
+ # )
1579
+ # fill_holes = gr.Checkbox(
1580
+ # label="Fill Holes in Mask",
1581
+ # value=True,
1582
+ # info="Fill small holes inside segmented objects."
1583
+ # )
1584
+ # show_overlay = gr.Checkbox(
1585
+ # label="Show Overlay on Original",
1586
+ # value=True,
1587
+ # info="Display the mask overlaid on the original image."
1588
+ # )
1589
+ # show_stats = gr.Checkbox(
1590
+ # label="Show Area & Confidence Stats",
1591
+ # value=True,
1592
+ # info="Display segmentation statistics like area and confidence."
1593
+ # )
1594
+
1595
+ # max_oligomer_size = gr.Slider(
1596
+ # minimum=10, maximum=1000, value=400, step=10,
1597
+ # label="Max Oligomer Size (px)",
1598
+ # interactive=True,
1599
+ # info="Objects smaller or equal to this area (in pixels) are oligomers; larger are fibrils."
1600
+ # )
1601
+ # keep_only_oligomers = gr.Checkbox(
1602
+ # label="Keep Only Oligomers (Remove Large Fibrils)",
1603
+ # value=False,
1604
+ # info="If enabled, only oligomers remain in the final mask."
1605
+ # )
1606
+
1607
+ # submit = gr.Button("🟢 Segment Image", variant="primary", elem_id="submit-btn")
1608
+
1609
+ # gr.Markdown("---")
1610
+
1611
+ # status = gr.Textbox(label="Status", interactive=False, lines=1, elem_id="status-msg")
1612
+
1613
+ # with gr.Row():
1614
+ # mask_output = gr.Image(label="Binary Mask", interactive=False, type="pil")
1615
+ # prob_output = gr.Image(label="Confidence Map", interactive=False, type="pil")
1616
+ # overlay_output = gr.Image(label="Overlay with Labels", interactive=False, type="pil")
1617
+
1618
+ # stats_output = gr.Textbox(label="Segmentation Stats", interactive=False, lines=8, elem_id="stats")
1619
+
1620
+ # # Optional: show counts separately with big text
1621
+ # oligomer_count_txt = gr.Textbox(label="Oligomer Count", interactive=False, lines=1)
1622
+ # fibril_count_txt = gr.Textbox(label="Fibril Count", interactive=False, lines=1)
1623
+
1624
+ # file_output = gr.File(label="Download Segmentation Mask")
1625
+
1626
+ # submit.click(
1627
+ # fn=predict,
1628
+ # inputs=[
1629
+ # input_img, model_selector, threshold_slider, use_otsu,
1630
+ # remove_noise, fill_holes, show_overlay, show_stats,
1631
+ # max_oligomer_size, keep_only_oligomers
1632
+ # ],
1633
+ # outputs=[
1634
+ # status, mask_output, prob_output, overlay_output,
1635
+ # file_output, stats_output, oligomer_count_txt, fibril_count_txt
1636
+ # ]
1637
+ # )
1638
+
1639
+ # demo.load(fn=update_model_info, inputs=model_selector, outputs=model_info)
1640
+
1641
+ # if __name__ == "__main__":
1642
+ # demo.launch()
1643
+
1644
+
1645
+
1646
+ # +++++++++++++++ Final Version: 1.8.0 ++++++++++++++++++++++
1647
+ # Last Updated: 10 July 2025
1648
+ # Improvements: Added model selection, better UI, and device handling, oligomer labeling and fibril labeling
1649
+ # +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
1650
+
1651
+
1652
  import os
1653
  import torch
1654
  import numpy as np
1655
+ from PIL import Image, ImageDraw, ImageFont
1656
  import albumentations as A
1657
  from albumentations.pytorch import ToTensorV2
1658
  import segmentation_models_pytorch as smp
 
1723
  model_cache[model_name] = model.to(device)
1724
  return model_cache[model_name]
1725
 
1726
def draw_labels_on_image(orig_img, binary_mask, max_oligomer_size,
                         fibril_length_thresh=100, circularity_thresh=1.0,
                         marker_radius=12, font_size=18):
    """
    Annotate connected components of ``binary_mask`` on top of ``orig_img``.

    Each region is classified as an oligomer (compact, circle-like, small)
    or a fibril (everything else) and a colored circle plus a text label
    ("O<n>" in green, "F<n>" in red) is drawn at its centroid.

    Parameters
    ----------
    orig_img : PIL.Image.Image
        Base image; it is copied, never modified in place.
    binary_mask : array-like of bool/int
        Foreground mask; connected components are found with
        ``skimage.measure.label``.
    max_oligomer_size : int
        Maximum region area (px) allowed for the oligomer class.
    fibril_length_thresh : float, optional
        Major-axis length threshold (px) for fibrils. NOTE(review): with the
        current classification every non-oligomer region is labeled a fibril
        regardless of length, so this value has no effect on the output; the
        parameter is kept for interface compatibility and future tuning.
    circularity_thresh : float, optional
        Minimum circularity (4*pi*area/perimeter**2) for the oligomer class.
        NOTE(review): the default 1.0 is the value of a perfect continuous
        circle; pixelated regions only reach it because skimage's perimeter
        estimate can undershoot — consider lowering (~0.8) after validation.
    marker_radius : int, optional
        Radius (px) of the circle drawn around each centroid.
    font_size : int, optional
        Point size requested for the TrueType label font.

    Returns
    -------
    (PIL.Image.Image, int, int)
        The annotated RGB overlay, the oligomer count, and the fibril count.
    """
    overlay = orig_img.convert("RGB").copy()
    draw = ImageDraw.Draw(overlay)

    # Fall back to PIL's built-in bitmap font when Arial is unavailable
    # (typical on Linux servers / hosted Spaces).
    try:
        font = ImageFont.truetype("arial.ttf", font_size)
    except IOError:
        font = ImageFont.load_default()

    labeled_mask = measure.label(binary_mask)
    regions = measure.regionprops(labeled_mask)

    oligomer_count = 0
    fibril_count = 0

    for region in regions:
        area = region.area
        # Guard against zero perimeter (e.g. single-pixel regions) before dividing.
        perimeter = region.perimeter if region.perimeter > 0 else 1
        circularity = 4 * np.pi * area / (perimeter ** 2)

        # regionprops centroids are (row, col); PIL drawing wants (x, y).
        row_c, col_c = region.centroid
        x, y = int(col_c), int(row_c)

        # Classification: compact AND small -> oligomer; everything else is
        # labeled a fibril so that every detected region receives a label.
        # (The original code had two identical non-oligomer branches — one
        # gated on major_axis_length — which are merged here; behavior is
        # unchanged because both produced the same fibril label/color.)
        if circularity >= circularity_thresh and area <= max_oligomer_size:
            oligomer_count += 1
            label_text = f"O{oligomer_count}"
            label_color = (0, 255, 0)  # green = oligomer
        else:
            fibril_count += 1
            label_text = f"F{fibril_count}"
            label_color = (255, 0, 0)  # red = fibril

        r = marker_radius
        draw.ellipse((x - r, y - r, x + r, y + r), outline=label_color, width=2)

        # Center the label text on the centroid using its rendered bounding box.
        bbox = draw.textbbox((0, 0), label_text, font=font)
        text_width = bbox[2] - bbox[0]
        text_height = bbox[3] - bbox[1]
        text_pos = (x - text_width // 2, y - text_height // 2)
        draw.text(text_pos, label_text, fill=label_color, font=font)

    return overlay, oligomer_count, fibril_count
1787
+
1788
  @torch.no_grad()
1789
+ def predict(image, model_name, threshold, use_otsu, remove_noise, fill_holes, show_overlay, show_stats, max_oligomer_size, keep_only_oligomers):
1790
  if image is None:
1791
+ return "❌ Please upload an image.", None, None, None, "", "", "", ""
1792
 
1793
  image = image.convert("L")
1794
  img_np = np.array(image)
 
1806
  binary_mask = morphology.remove_small_objects(binary_mask > 0, 64)
1807
  if fill_holes:
1808
  binary_mask = morphology.remove_small_holes(binary_mask > 0, 64)
1809
+
1810
  binary_mask = binary_mask.astype(np.float32)
1811
 
1812
+ # If user wants to keep only oligomers, remove large fibrils from mask
1813
+ if keep_only_oligomers:
1814
+ labeled_mask = measure.label(binary_mask)
1815
+ regions = measure.regionprops(labeled_mask)
1816
+ filtered_mask = np.zeros_like(binary_mask)
1817
+ for region in regions:
1818
+ if region.area <= max_oligomer_size:
1819
+ filtered_mask[labeled_mask == region.label] = 1
1820
+ binary_mask = filtered_mask.astype(np.float32)
1821
+
1822
  mask_img = Image.fromarray((binary_mask * 255).astype(np.uint8))
1823
  prob_img = Image.fromarray((pred * 255).astype(np.uint8))
1824
 
1825
+ overlay_img = None
1826
+ oligomer_count = 0
1827
+ fibril_count = 0
1828
+
1829
  if show_overlay:
1830
+ # Resize mask to original image size
1831
  mask_resized = mask_img.resize(image.size, resample=Image.NEAREST).convert("RGB")
1832
+ # Blend original and mask
1833
+ base_overlay = Image.blend(image.convert("RGB"), mask_resized, alpha=0.4)
1834
+
1835
+ # Draw labels on overlay
1836
+ # overlay_img, oligomer_count, fibril_count = draw_labels_on_image(base_overlay, np.array(mask_img) > 0, max_oligomer_size)
1837
+ overlay_img, oligomer_count, fibril_count = draw_labels_on_image(base_overlay, np.array(mask_img) > 0, max_oligomer_size, fibril_length_thresh=100)
1838
  else:
1839
  overlay_img = None
1840
 
 
1844
  area = np.sum(binary_mask)
1845
  mean_conf = np.mean(pred[binary_mask > 0]) if area > 0 else 0
1846
  std_conf = np.std(pred[binary_mask > 0]) if area > 0 else 0
1847
+ total_objects = labeled_mask.max()
1848
+
1849
+ # Count oligomers and fibrils
1850
+ oligomers = 0
1851
+ fibrils = 0
1852
+ for region in measure.regionprops(labeled_mask):
1853
+ if region.area <= max_oligomer_size:
1854
+ oligomers += 1
1855
+ else:
1856
+ fibrils += 1
1857
+
1858
  stats_text = (
1859
+ f"🧮 Stats:\n"
1860
+ f" - Area (px): {area:.0f}\n"
1861
+ f" - Total Objects: {total_objects}\n"
1862
+ f" - Oligomers (small): {oligomers}\n"
1863
+ f" - Fibrils (large): {fibrils}\n"
1864
+ f" - Mean Confidence: {mean_conf:.3f}\n"
1865
+ f" - Std Confidence: {std_conf:.3f}"
1866
  )
1867
 
1868
  mask_img.save("predicted_mask.png")
1869
 
1870
+ return "✅ Segmentation Complete!", mask_img, prob_img, overlay_img, "predicted_mask.png", stats_text, str(oligomer_count), str(fibril_count)
 
1871
 
1872
  css = """
1873
  body {
 
1981
  info="Display segmentation statistics like area and confidence."
1982
  )
1983
 
1984
+ max_oligomer_size = gr.Slider(
1985
+ minimum=10, maximum=1000, value=400, step=10,
1986
+ label="Max Oligomer Size (px)",
1987
+ interactive=True,
1988
+ info="Objects smaller or equal to this area (in pixels) are oligomers; larger are fibrils."
1989
+ )
1990
+ keep_only_oligomers = gr.Checkbox(
1991
+ label="Keep Only Oligomers (Remove Large Fibrils)",
1992
+ value=False,
1993
+ info="If enabled, only oligomers remain in the final mask."
1994
+ )
1995
+
1996
  submit = gr.Button("🟢 Segment Image", variant="primary", elem_id="submit-btn")
1997
 
1998
  gr.Markdown("---")
 
2002
  with gr.Row():
2003
  mask_output = gr.Image(label="Binary Mask", interactive=False, type="pil")
2004
  prob_output = gr.Image(label="Confidence Map", interactive=False, type="pil")
2005
+ overlay_output = gr.Image(label="Overlay with Labels", interactive=False, type="pil")
2006
+
2007
+ stats_output = gr.Textbox(label="Segmentation Stats", interactive=False, lines=8, elem_id="stats")
2008
 
2009
+ # Optional: show counts separately with big text
2010
+ oligomer_count_txt = gr.Textbox(label="Oligomer Count", interactive=False, lines=1)
2011
+ fibril_count_txt = gr.Textbox(label="Fibril Count", interactive=False, lines=1)
2012
 
2013
  file_output = gr.File(label="Download Segmentation Mask")
2014
 
 
2016
  fn=predict,
2017
  inputs=[
2018
  input_img, model_selector, threshold_slider, use_otsu,
2019
+ remove_noise, fill_holes, show_overlay, show_stats,
2020
+ max_oligomer_size, keep_only_oligomers
2021
  ],
2022
  outputs=[
2023
  status, mask_output, prob_output, overlay_output,
2024
+ file_output, stats_output, oligomer_count_txt, fibril_count_txt
2025
  ]
2026
  )
2027
 
predicted_mask.png CHANGED