himanshuch8055 commited on
Commit
e9bd06c
·
1 Parent(s): a700b86

Add support for UNet++ with EfficientNet-B3 and implement model caching

Browse files
app.py CHANGED
@@ -217,6 +217,11 @@ MODEL_OPTIONS = {
217
  "path": "./model/unet_fibril_seg_model.pth",
218
  "encoder": "resnet34",
219
  "architecture": "Unet"
 
 
 
 
 
220
  }
221
  }
222
 
@@ -231,7 +236,39 @@ def get_transform(size):
231
  transform = get_transform(512)
232
 
233
  # ─── Model Loader ──────────────────────────────────────────
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
  def load_model(model_name):
 
 
 
235
  config = MODEL_OPTIONS[model_name]
236
  if config["architecture"] == "UnetPlusPlus":
237
  model = smp.UnetPlusPlus(
@@ -256,7 +293,8 @@ def load_model(model_name):
256
 
257
  model.load_state_dict(torch.load(config["path"], map_location=device))
258
  model.eval()
259
- return model.to(device)
 
260
 
261
  # ─── Prediction Function ───────────────────────────────────
262
  def predict(image, model_name):
 
217
  "path": "./model/unet_fibril_seg_model.pth",
218
  "encoder": "resnet34",
219
  "architecture": "Unet"
220
+ },
221
+ "UNet++ (efficientnet-b3)": {
222
+ "path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
223
+ "encoder": "efficientnet-b3",
224
+ "architecture": "UnetPlusPlus"
225
  }
226
  }
227
 
 
236
  transform = get_transform(512)
237
 
238
  # ─── Model Loader ──────────────────────────────────────────
239
+ # def load_model(model_name):
240
+ # config = MODEL_OPTIONS[model_name]
241
+ # if config["architecture"] == "UnetPlusPlus":
242
+ # model = smp.UnetPlusPlus(
243
+ # encoder_name=config["encoder"],
244
+ # encoder_weights="imagenet",
245
+ # decoder_channels=(256, 128, 64, 32, 16),
246
+ # in_channels=1,
247
+ # classes=1,
248
+ # activation=None
249
+ # )
250
+ # elif config["architecture"] == "Unet":
251
+ # model = smp.Unet(
252
+ # encoder_name=config["encoder"],
253
+ # encoder_weights="imagenet",
254
+ # decoder_channels=(256, 128, 64, 32, 16),
255
+ # in_channels=1,
256
+ # classes=1,
257
+ # activation=None
258
+ # )
259
+ # else:
260
+ # raise ValueError(f"Unsupported architecture: {config['architecture']}")
261
+
262
+ # model.load_state_dict(torch.load(config["path"], map_location=device))
263
+ # model.eval()
264
+ # return model.to(device)
265
+
266
+ model_cache = {}
267
+
268
  def load_model(model_name):
269
+ if model_name in model_cache:
270
+ return model_cache[model_name]
271
+
272
  config = MODEL_OPTIONS[model_name]
273
  if config["architecture"] == "UnetPlusPlus":
274
  model = smp.UnetPlusPlus(
 
293
 
294
  model.load_state_dict(torch.load(config["path"], map_location=device))
295
  model.eval()
296
+ model_cache[model_name] = model.to(device)
297
+ return model_cache[model_name]
298
 
299
  # ─── Prediction Function ───────────────────────────────────
300
  def predict(image, model_name):
model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2170c5830e61a34edfca41e64e76f19fc41385b8517fbd4c01282115c5f7fed
3
+ size 55158635