File size: 1,546 Bytes
6b92ff7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
""" models from the segmentation_models_pytorch library
    install the library before using these models
"""
import torch
import torch.nn as nn
import segmentation_models_pytorch as smp


class SegModel(nn.Module):
    def __init__(self, architecture='Unet', 
                    encoder_name='mobilenet_v2', 
                    encoder_weights='imagenet',
                    encoder_depth=4,
                    in_channels=3, 
                    classes=4):
        super().__init__()

        """
        architecture = {
            Unet, UnetPlusPlus, EfficientUnetPlusPlus, DeepLabV3, DeepLabV3+
        }
        encoder_name = {
            resnet18 / 11M, resnet34 / 21M, resnet50 / 23M, resnet101 / 42M
            densenet121 / 6M, densenet169 / 12M, densenet201 / 18M
            efficientnet-b0 / 4M, -b1 / 6M, -b2 / 7M, -b3 / 10M, -b4 / 17M, -b5 / 28M
            mobilenet_v2 / 2M
        }
        encoder_weights = {
            "imagenet", None
        }
        in_channels = 3
        classes = 4
        """

        if encoder_depth == 4:
            decoder_channels = (256, 128, 64, 32)

        exec("""self.model = smp.%s(
            encoder_name='%s',
            encoder_weights='%s',
            encoder_depth=%s,
            decoder_channels=%s,
            in_channels=%s,
            classes=%s,
            activation=None,
        )
        """ % (architecture, encoder_name, encoder_weights, encoder_depth, decoder_channels, in_channels, classes))


    def forward(self, x):
        return self.model(x)