# File size: 1,868 Bytes
# 84b07bc
"""

Baseline Model Factory



This module provides a factory function to create baseline segmentation models

using the segmentation-models-pytorch library.



Supported architectures:

- Unet

- UnetPlusPlus

- PSPNet

- PAN

- MAnet

- Linknet

- FPN

- DeepLabV3Plus

- DeepLabV3

"""

import torch
import segmentation_models_pytorch as smp


def create_model(architecture='Unet', encoder_name='resnet50', in_channels=3,
                 num_classes=1, encoder_weights='imagenet'):
    """Create a baseline segmentation model using segmentation-models-pytorch.

    The architecture name is checked against ``SUPPORTED_ARCHITECTURES``
    before it is looked up on the ``smp`` package, so a typo (or the name of
    some unrelated ``smp`` attribute) fails with a clear ``ValueError``
    instead of an error raised by the dynamic attribute lookup.

    Args:
        architecture (str): Model architecture name. One of:
            'Unet', 'UnetPlusPlus', 'PSPNet', 'PAN', 'MAnet',
            'Linknet', 'FPN', 'DeepLabV3Plus', 'DeepLabV3'
        encoder_name (str): Encoder backbone name (default: 'resnet50')
        in_channels (int): Number of input channels (default: 3)
        num_classes (int): Number of output classes (default: 1)
        encoder_weights (str): Pretrained weights source (default: 'imagenet')

    Returns:
        nn.Module: Segmentation model, moved to CUDA when
            ``torch.cuda.is_available()`` is true and to the CPU otherwise.

    Raises:
        ValueError: If ``architecture`` is not listed in
            ``SUPPORTED_ARCHITECTURES``.

    Example:
        >>> model = create_model('DeepLabV3Plus', encoder_name='resnet50')
        >>> model = model.eval()
        >>> device = next(model.parameters()).device
        >>> x = torch.randn(1, 3, 448, 448, device=device)
        >>> output = model(x)
        >>> output.shape
        torch.Size([1, 1, 448, 448])
    """
    # Validate first: getattr() on the package would otherwise accept any
    # attribute name, including ones that are not model classes at all.
    if architecture not in SUPPORTED_ARCHITECTURES:
        raise ValueError(
            f"Unsupported architecture {architecture!r}; expected one of: "
            f"{', '.join(SUPPORTED_ARCHITECTURES)}"
        )

    model_cls = getattr(smp, architecture)
    model = model_cls(
        encoder_name=encoder_name,
        encoder_weights=encoder_weights,
        in_channels=in_channels,
        classes=num_classes,
    )

    # Place the model on the GPU when one is available, else keep it on CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    return model.to(device)


# Architecture names accepted by create_model(); each one is the name of a
# model class looked up on the segmentation_models_pytorch package.
SUPPORTED_ARCHITECTURES = [
    'Unet', 'UnetPlusPlus', 'PSPNet', 'PAN', 'MAnet',
    'Linknet', 'FPN', 'DeepLabV3Plus', 'DeepLabV3',
]