himanshuch8055 commited on
Commit
a700b86
·
1 Parent(s): 3f1d82b

Enhance fibril segmentation app with model selection, improved UI, and device handling; add UNet model weights

Browse files
Files changed (2) hide show
  1. app.py +167 -26
  2. model/unet_fibril_seg_model.pth +3 -0
app.py CHANGED
@@ -81,6 +81,118 @@
81
  # # Last Updated: 08 July 2025
82
  # # Improvements: Added examples, better UI, and device handling
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  import os
85
  import torch
86
  import numpy as np
@@ -90,29 +202,23 @@ from albumentations.pytorch import ToTensorV2
90
  import segmentation_models_pytorch as smp
91
  import gradio as gr
92
 
93
- # ─── Configuration ─────────────────────────────────────────
94
- CONFIG = {
95
- "model_path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
96
- "img_size": 512
97
- }
98
-
99
  # ─── Device Setup ──────────────────────────────────────────
100
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
101
  print(f"✅ Using device: {device}")
102
 
103
- # ─── Load Model ────────────────────────────────────────────
104
- model = smp.UnetPlusPlus(
105
- encoder_name='resnet34',
106
- encoder_depth=5,
107
- encoder_weights='imagenet',
108
- decoder_channels=(256, 128, 64, 32, 16),
109
- in_channels=1,
110
- classes=1,
111
- activation=None
112
- ).to(device)
113
-
114
- model.load_state_dict(torch.load(CONFIG["model_path"], map_location=device))
115
- model.eval()
116
 
117
  # ─── Transform Function ────────────────────────────────────
118
  def get_transform(size):
@@ -122,14 +228,44 @@ def get_transform(size):
122
  ToTensorV2()
123
  ])
124
 
125
- transform = get_transform(CONFIG["img_size"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
 
127
  # ─── Prediction Function ───────────────────────────────────
128
- def predict(image):
129
- image = image.convert("L") # Ensure grayscale
130
  img_np = np.array(image)
131
  img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
132
 
 
 
133
  with torch.no_grad():
134
  pred = torch.sigmoid(model(img_tensor))
135
  mask = (pred > 0.5).float().cpu().squeeze().numpy()
@@ -137,7 +273,7 @@ def predict(image):
137
  mask_img = Image.fromarray((mask * 255).astype(np.uint8))
138
  return mask_img
139
 
140
- # ─── Gradio UI (Improved) ──────────────────────────────────
141
  examples = [
142
  ["examples/example1.jpg"],
143
  ["examples/example2.jpg"],
@@ -148,6 +284,7 @@ examples = [
148
  ["examples/example7.jpg"]
149
  ]
150
 
 
151
  css = """
152
  .gradio-container {
153
  max-width: 950px;
@@ -163,9 +300,13 @@ css = """
163
  }
164
  """
165
 
 
166
  with gr.Blocks(css=css) as demo:
167
- gr.Markdown("## 🧬 Fibril Segmentation with UNet++")
168
- gr.Markdown("Upload a **grayscale microscopy image**, and this model will predict the **segmentation mask of fibrillar structures**.\n\nModel: ResNet34 encoder + UNet++ decoder")
 
 
 
169
 
170
  with gr.Row():
171
  input_img = gr.Image(label="Upload Microscopy Image", type="pil")
@@ -173,7 +314,7 @@ with gr.Blocks(css=css) as demo:
173
 
174
  submit_btn = gr.Button("Segment Image")
175
 
176
- submit_btn.click(fn=predict, inputs=input_img, outputs=output_mask)
177
 
178
  gr.Examples(
179
  examples=examples,
 
81
  # # Last Updated: 08 July 2025
82
  # # Improvements: Added examples, better UI, and device handling
83
 
84
+ # import os
85
+ # import torch
86
+ # import numpy as np
87
+ # from PIL import Image
88
+ # import albumentations as A
89
+ # from albumentations.pytorch import ToTensorV2
90
+ # import segmentation_models_pytorch as smp
91
+ # import gradio as gr
92
+
93
+ # # ─── Configuration ─────────────────────────────────────────
94
+ # CONFIG = {
95
+ # "model_path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
96
+ # "img_size": 512
97
+ # }
98
+
99
+ # # ─── Device Setup ──────────────────────────────────────────
100
+ # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
101
+ # print(f"✅ Using device: {device}")
102
+
103
+ # # ─── Load Model ────────────────────────────────────────────
104
+ # model = smp.UnetPlusPlus(
105
+ # encoder_name='resnet34',
106
+ # encoder_depth=5,
107
+ # encoder_weights='imagenet',
108
+ # decoder_channels=(256, 128, 64, 32, 16),
109
+ # in_channels=1,
110
+ # classes=1,
111
+ # activation=None
112
+ # ).to(device)
113
+
114
+ # model.load_state_dict(torch.load(CONFIG["model_path"], map_location=device))
115
+ # model.eval()
116
+
117
+ # # ─── Transform Function ────────────────────────────────────
118
+ # def get_transform(size):
119
+ # return A.Compose([
120
+ # A.Resize(size, size),
121
+ # A.Normalize(mean=(0.5,), std=(0.5,)),
122
+ # ToTensorV2()
123
+ # ])
124
+
125
+ # transform = get_transform(CONFIG["img_size"])
126
+
127
+ # # ─── Prediction Function ───────────────────────────────────
128
+ # def predict(image):
129
+ # image = image.convert("L") # Ensure grayscale
130
+ # img_np = np.array(image)
131
+ # img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
132
+
133
+ # with torch.no_grad():
134
+ # pred = torch.sigmoid(model(img_tensor))
135
+ # mask = (pred > 0.5).float().cpu().squeeze().numpy()
136
+
137
+ # mask_img = Image.fromarray((mask * 255).astype(np.uint8))
138
+ # return mask_img
139
+
140
+ # # ─── Gradio UI (Improved) ──────────────────────────────────
141
+ # examples = [
142
+ # ["examples/example1.jpg"],
143
+ # ["examples/example2.jpg"],
144
+ # ["examples/example3.jpg"],
145
+ # ["examples/example4.jpg"],
146
+ # ["examples/example5.jpg"],
147
+ # ["examples/example6.jpg"],
148
+ # ["examples/example7.jpg"]
149
+ # ]
150
+
151
+ # css = """
152
+ # .gradio-container {
153
+ # max-width: 950px;
154
+ # margin: auto;
155
+ # }
156
+ # .gr-button {
157
+ # background-color: #4a90e2;
158
+ # color: white;
159
+ # border-radius: 5px;
160
+ # }
161
+ # .gr-button:hover {
162
+ # background-color: #357ABD;
163
+ # }
164
+ # """
165
+
166
+ # with gr.Blocks(css=css) as demo:
167
+ # gr.Markdown("## 🧬 Fibril Segmentation with UNet++")
168
+ # gr.Markdown("Upload a **grayscale microscopy image**, and this model will predict the **segmentation mask of fibrillar structures**.\n\nModel: ResNet34 encoder + UNet++ decoder")
169
+
170
+ # with gr.Row():
171
+ # input_img = gr.Image(label="Upload Microscopy Image", type="pil")
172
+ # output_mask = gr.Image(label="Predicted Segmentation Mask", type="pil")
173
+
174
+ # submit_btn = gr.Button("Segment Image")
175
+
176
+ # submit_btn.click(fn=predict, inputs=input_img, outputs=output_mask)
177
+
178
+ # gr.Examples(
179
+ # examples=examples,
180
+ # inputs=input_img,
181
+ # label="Try with Example Images",
182
+ # cache_examples=False
183
+ # )
184
+
185
+ # # ─── Launch App ────────────────────────────────────────────
186
+ # if __name__ == "__main__":
187
+ # demo.launch()
188
+
189
+
190
+
191
+ # +++++++++++++++ Final Version: 1.2.0 ++++++++++++++++++++++
192
+ # Last Updated: 08 July 2025
193
+ # Improvements: Added model selection, better UI, and device handling
194
+
195
+
196
  import os
197
  import torch
198
  import numpy as np
 
202
  import segmentation_models_pytorch as smp
203
  import gradio as gr
204
 
 
 
 
 
 
 
205
  # ─── Device Setup ──────────────────────────────────────────
206
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
207
  print(f"✅ Using device: {device}")
208
 
209
+ # ─── Model Configurations ──────────────────────────────────
210
+ MODEL_OPTIONS = {
211
+ "UNet++ (ResNet34)": {
212
+ "path": "./model/encoder_resnet34_decoder_UnetPlusPlus_fibril_seg_model.pth",
213
+ "encoder": "resnet34",
214
+ "architecture": "UnetPlusPlus"
215
+ },
216
+ "UNet (ResNet34)": {
217
+ "path": "./model/unet_fibril_seg_model.pth",
218
+ "encoder": "resnet34",
219
+ "architecture": "Unet"
220
+ }
221
+ }
222
 
223
  # ─── Transform Function ────────────────────────────────────
224
  def get_transform(size):
 
228
  ToTensorV2()
229
  ])
230
 
231
+ transform = get_transform(512)
232
+
233
+ # ─── Model Loader ──────────────────────────────────────────
234
+ def load_model(model_name):
235
+ config = MODEL_OPTIONS[model_name]
236
+ if config["architecture"] == "UnetPlusPlus":
237
+ model = smp.UnetPlusPlus(
238
+ encoder_name=config["encoder"],
239
+ encoder_weights="imagenet",
240
+ decoder_channels=(256, 128, 64, 32, 16),
241
+ in_channels=1,
242
+ classes=1,
243
+ activation=None
244
+ )
245
+ elif config["architecture"] == "Unet":
246
+ model = smp.Unet(
247
+ encoder_name=config["encoder"],
248
+ encoder_weights="imagenet",
249
+ decoder_channels=(256, 128, 64, 32, 16),
250
+ in_channels=1,
251
+ classes=1,
252
+ activation=None
253
+ )
254
+ else:
255
+ raise ValueError(f"Unsupported architecture: {config['architecture']}")
256
+
257
+ model.load_state_dict(torch.load(config["path"], map_location=device))
258
+ model.eval()
259
+ return model.to(device)
260
 
261
  # ─── Prediction Function ───────────────────────────────────
262
+ def predict(image, model_name):
263
+ image = image.convert("L")
264
  img_np = np.array(image)
265
  img_tensor = transform(image=img_np)["image"].unsqueeze(0).to(device)
266
 
267
+ model = load_model(model_name)
268
+
269
  with torch.no_grad():
270
  pred = torch.sigmoid(model(img_tensor))
271
  mask = (pred > 0.5).float().cpu().squeeze().numpy()
 
273
  mask_img = Image.fromarray((mask * 255).astype(np.uint8))
274
  return mask_img
275
 
276
+ # ─── Example Images ────────────────────────────────────────
277
  examples = [
278
  ["examples/example1.jpg"],
279
  ["examples/example2.jpg"],
 
284
  ["examples/example7.jpg"]
285
  ]
286
 
287
+ # ─── Custom CSS ────────────────────────────────────────────
288
  css = """
289
  .gradio-container {
290
  max-width: 950px;
 
300
  }
301
  """
302
 
303
+ # ─── Gradio UI ─────────────────────────────────────────────
304
  with gr.Blocks(css=css) as demo:
305
+ gr.Markdown("## 🧬 Fibril Segmentation Interface")
306
+ gr.Markdown("Choose a model and upload a grayscale microscopy image. The model will predict the **fibrillar structure mask**.")
307
+
308
+ with gr.Row():
309
+ model_selector = gr.Dropdown(choices=list(MODEL_OPTIONS.keys()), value="UNet++ (ResNet34)", label="Select Model")
310
 
311
  with gr.Row():
312
  input_img = gr.Image(label="Upload Microscopy Image", type="pil")
 
314
 
315
  submit_btn = gr.Button("Segment Image")
316
 
317
+ submit_btn.click(fn=predict, inputs=[input_img, model_selector], outputs=output_mask)
318
 
319
  gr.Examples(
320
  examples=examples,
model/unet_fibril_seg_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:39aabaff5e7006840147b16e67ca995fee643742b3d44b3ade469485384fd153
3
+ size 97898267