Spaces:
Sleeping
Sleeping
Commit ·
e9bd06c
1
Parent(s): a700b86
Add support for UNet++ with EfficientNet-B3 and implement model caching
Browse files
app.py
CHANGED
|
@@ -217,6 +217,11 @@ MODEL_OPTIONS = {
|
|
| 217 |
"path": "./model/unet_fibril_seg_model.pth",
|
| 218 |
"encoder": "resnet34",
|
| 219 |
"architecture": "Unet"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
}
|
| 221 |
}
|
| 222 |
|
|
@@ -231,7 +236,39 @@ def get_transform(size):
|
|
| 231 |
transform = get_transform(512)
|
| 232 |
|
| 233 |
# ─── Model Loader ──────────────────────────────────────────
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
def load_model(model_name):
|
|
|
|
|
|
|
|
|
|
| 235 |
config = MODEL_OPTIONS[model_name]
|
| 236 |
if config["architecture"] == "UnetPlusPlus":
|
| 237 |
model = smp.UnetPlusPlus(
|
|
@@ -256,7 +293,8 @@ def load_model(model_name):
|
|
| 256 |
|
| 257 |
model.load_state_dict(torch.load(config["path"], map_location=device))
|
| 258 |
model.eval()
|
| 259 |
-
|
|
|
|
| 260 |
|
| 261 |
# ─── Prediction Function ───────────────────────────────────
|
| 262 |
def predict(image, model_name):
|
|
|
|
| 217 |
"path": "./model/unet_fibril_seg_model.pth",
|
| 218 |
"encoder": "resnet34",
|
| 219 |
"architecture": "Unet"
|
| 220 |
+
},
|
| 221 |
+
"UNet++ (efficientnet-b3)": {
|
| 222 |
+
"path": "./model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth",
|
| 223 |
+
"encoder": "efficientnet-b3",
|
| 224 |
+
"architecture": "UnetPlusPlus"
|
| 225 |
}
|
| 226 |
}
|
| 227 |
|
|
|
|
| 236 |
transform = get_transform(512)
|
| 237 |
|
| 238 |
# ─── Model Loader ──────────────────────────────────────────
|
| 239 |
+
# def load_model(model_name):
|
| 240 |
+
# config = MODEL_OPTIONS[model_name]
|
| 241 |
+
# if config["architecture"] == "UnetPlusPlus":
|
| 242 |
+
# model = smp.UnetPlusPlus(
|
| 243 |
+
# encoder_name=config["encoder"],
|
| 244 |
+
# encoder_weights="imagenet",
|
| 245 |
+
# decoder_channels=(256, 128, 64, 32, 16),
|
| 246 |
+
# in_channels=1,
|
| 247 |
+
# classes=1,
|
| 248 |
+
# activation=None
|
| 249 |
+
# )
|
| 250 |
+
# elif config["architecture"] == "Unet":
|
| 251 |
+
# model = smp.Unet(
|
| 252 |
+
# encoder_name=config["encoder"],
|
| 253 |
+
# encoder_weights="imagenet",
|
| 254 |
+
# decoder_channels=(256, 128, 64, 32, 16),
|
| 255 |
+
# in_channels=1,
|
| 256 |
+
# classes=1,
|
| 257 |
+
# activation=None
|
| 258 |
+
# )
|
| 259 |
+
# else:
|
| 260 |
+
# raise ValueError(f"Unsupported architecture: {config['architecture']}")
|
| 261 |
+
|
| 262 |
+
# model.load_state_dict(torch.load(config["path"], map_location=device))
|
| 263 |
+
# model.eval()
|
| 264 |
+
# return model.to(device)
|
| 265 |
+
|
| 266 |
+
model_cache = {}
|
| 267 |
+
|
| 268 |
def load_model(model_name):
|
| 269 |
+
if model_name in model_cache:
|
| 270 |
+
return model_cache[model_name]
|
| 271 |
+
|
| 272 |
config = MODEL_OPTIONS[model_name]
|
| 273 |
if config["architecture"] == "UnetPlusPlus":
|
| 274 |
model = smp.UnetPlusPlus(
|
|
|
|
| 293 |
|
| 294 |
model.load_state_dict(torch.load(config["path"], map_location=device))
|
| 295 |
model.eval()
|
| 296 |
+
model_cache[model_name] = model.to(device)
|
| 297 |
+
return model_cache[model_name]
|
| 298 |
|
| 299 |
# ─── Prediction Function ───────────────────────────────────
|
| 300 |
def predict(image, model_name):
|
model/encoder_efficientnet-b3_decoder_UnetPlusPlus_fibril_seg_model.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b2170c5830e61a34edfca41e64e76f19fc41385b8517fbd4c01282115c5f7fed
|
| 3 |
+
size 55158635
|