RuthvikBandari commited on
Commit
9696169
·
verified ·
1 Parent(s): 88a3f32

Upload src/models/unetpp.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src/models/unetpp.py +41 -0
src/models/unetpp.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """DiaFoot.AI v2 — U-Net++ via Segmentation Models PyTorch.
2
+
3
+ Phase 2, Commit 9: Baseline single-task segmentation model.
4
+ """
5
+
6
+ from __future__ import annotations
7
+
8
+ import segmentation_models_pytorch as smp
9
+ import torch.nn as nn # noqa: TC002
10
+
11
+
12
+ def build_unetpp(
13
+ encoder_name: str = "efficientnet-b4",
14
+ encoder_weights: str | None = "imagenet",
15
+ in_channels: int = 3,
16
+ classes: int = 1,
17
+ decoder_attention_type: str | None = "scse",
18
+ deep_supervision: bool = True,
19
+ ) -> nn.Module:
20
+ """Build a U-Net++ model via SMP.
21
+
22
+ Args:
23
+ encoder_name: Encoder backbone name (from timm/SMP).
24
+ encoder_weights: Pretrained weights ('imagenet' or None).
25
+ in_channels: Number of input channels.
26
+ classes: Number of output segmentation classes.
27
+ decoder_attention_type: Attention type ('scse' or None).
28
+ deep_supervision: Enable deep supervision for better gradients.
29
+
30
+ Returns:
31
+ SMP UnetPlusPlus model.
32
+ """
33
+ return smp.UnetPlusPlus(
34
+ encoder_name=encoder_name,
35
+ encoder_weights=encoder_weights,
36
+ in_channels=in_channels,
37
+ classes=classes,
38
+ decoder_attention_type=decoder_attention_type,
39
+ encoder_depth=5,
40
+ decoder_channels=(256, 128, 64, 32, 16),
41
+ )