| """ models from the segmentation_models_pytorch library |
| install the library before using these models |
| """ |
| import torch |
| import torch.nn as nn |
| import segmentation_models_pytorch as smp |
|
|
|
|
class SegModel(nn.Module):
    """Thin ``nn.Module`` wrapper around a segmentation_models_pytorch model.

    The wrapped network is created in ``__init__`` by looking up
    ``architecture`` as an attribute of ``segmentation_models_pytorch`` and
    calling it with the remaining arguments; it is stored on ``self.model``.
    ``forward`` simply delegates to that network.

    Options noted by the original author (``name / parameter count``):

    architecture = {
        Unet, UnetPlusPlus, EfficientUnetPlusPlus, DeepLabV3, DeepLabV3+
    }
    encoder_name = {
        resnet18 / 11M, resnet34 / 21M, resnet50 / 23M, resnet101 / 42M
        densenet121 / 6M, densenet169 / 12M, densenet201 / 18M
        efficientnet-b0 / 4M, -b1 / 6M, -b2 / 7M, -b3 / 10M, -b4 / 17M, -b5 / 28M
        mobilenet_v2 / 2M
    }
    encoder_weights = {
        "imagenet", None
    }
    in_channels = 3
    classes = 4

    NOTE(review): ``architecture`` must be a valid attribute name of the
    installed library, so "DeepLabV3+" as written above cannot be resolved
    (the library presumably exposes it as ``DeepLabV3Plus`` -- confirm).
    Every architecture receives a tuple ``decoder_channels``; verify that the
    DeepLab variants accept that before relying on them.
    """

    # Decoder widths, widest first. The first ``encoder_depth`` entries are
    # used, so depth 4 yields (256, 128, 64, 32) exactly as before.
    _DECODER_CHANNELS = (256, 128, 64, 32, 16)

    def __init__(self, architecture='Unet',
                 encoder_name='mobilenet_v2',
                 encoder_weights='imagenet',
                 encoder_depth=4,
                 in_channels=3,
                 classes=4):
        """Build the wrapped segmentation network.

        Args:
            architecture: Name of the model class inside
                ``segmentation_models_pytorch`` (e.g. ``'Unet'``).
            encoder_name: Backbone identifier forwarded to the model.
            encoder_weights: Pretrained weight set (e.g. ``'imagenet'``) or
                ``None``; forwarded unchanged, so ``None`` stays ``None``.
            encoder_depth: Number of encoder stages; also selects how many
                decoder widths are passed as ``decoder_channels``.
            in_channels: Number of input channels forwarded to the model.
            classes: Number of output classes forwarded to the model.

        Raises:
            ValueError: If ``encoder_depth`` has no matching decoder layout
                or ``architecture`` is not a callable attribute of the
                library.
        """
        super().__init__()

        # Validate the depth first: previously any value other than 4 left
        # ``decoder_channels`` unassigned and crashed with UnboundLocalError.
        max_depth = len(self._DECODER_CHANNELS)
        if not 1 <= encoder_depth <= max_depth:
            raise ValueError(
                "encoder_depth must be between 1 and %d, got %r"
                % (max_depth, encoder_depth)
            )
        decoder_channels = self._DECODER_CHANNELS[:encoder_depth]

        # Resolve the model class by attribute lookup instead of exec() on a
        # formatted source string: no arbitrary code execution, and argument
        # values (notably encoder_weights=None) keep their real types.
        model_cls = getattr(smp, architecture, None)
        if not callable(model_cls):
            raise ValueError(
                "unknown segmentation architecture: %r" % (architecture,)
            )

        self.model = model_cls(
            encoder_name=encoder_name,
            encoder_weights=encoder_weights,
            encoder_depth=encoder_depth,
            decoder_channels=decoder_channels,
            in_channels=in_channels,
            classes=classes,
            activation=None,
        )

    def forward(self, x):
        """Run ``x`` through the wrapped model and return its output."""
        return self.model(x)
|
|
|
|