File size: 2,191 Bytes
5acc7ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import torch.nn as nn


# ============================================================================
# Activation Module
# ============================================================================

class Activation(nn.Module):
    """Wrap an activation function selected by name or supplied directly.

    Args:
        activation: Which activation to apply. Accepted values:

            * ``None`` or ``'identity'`` -> ``nn.Identity()``
            * ``'sigmoid'`` -> ``nn.Sigmoid()``
            * ``'softmax'`` / ``'softmax2d'`` -> ``nn.Softmax(dim=1)``
            * ``'logsoftmax'`` -> ``nn.LogSoftmax(dim=1)``
            * ``'tanh'`` -> ``nn.Tanh()``
            * ``'relu'`` -> ``nn.ReLU(inplace=True)`` (overwrites its input)
            * an ``nn.Module`` subclass -> instantiated with no arguments
            * any other callable (a module instance, ``torch.sigmoid``, ...)
              -> stored and applied to the input as-is

    Raises:
        ValueError: If ``activation`` is an unrecognised string or any other
            non-callable value.
    """

    def __init__(self, activation=None):
        super().__init__()

        if activation is None or activation == 'identity':
            self.activation = nn.Identity()
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'softmax':
            self.activation = nn.Softmax(dim=1)
        elif activation == 'softmax2d':
            # Same operator as 'softmax' (normalise over dim 1); kept as a
            # separate name so callers can use either spelling.
            self.activation = nn.Softmax(dim=1)
        elif activation == 'logsoftmax':
            self.activation = nn.LogSoftmax(dim=1)
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        elif activation == 'relu':
            # inplace=True: the input tensor passed to forward() is modified.
            self.activation = nn.ReLU(inplace=True)
        elif isinstance(activation, type) and issubclass(activation, nn.Module):
            # A module *class* is callable too. Storing it as-is would make
            # forward() call its constructor with the tensor and return a
            # module instead of a tensor. Instantiate it here instead.
            self.activation = activation()
        elif callable(activation):
            self.activation = activation
        else:
            raise ValueError(
                'Activation should be callable/sigmoid/softmax/softmax2d/'
                f'logsoftmax/tanh/relu/identity/None; got {activation}'
            )

    def forward(self, x):
        """Apply the configured activation to ``x`` and return the result."""
        return self.activation(x)
    
# ============================================================================
# Segmentation Head (nn.Sequential style)
# ============================================================================

class SegmentationHead(nn.Sequential):
    """Segmentation head: convolution, optional bilinear upsampling, activation.

    Built as an ``nn.Sequential`` with exactly three children, in this order.
    The conv is child ``0``, so its parameters appear as ``0.weight`` /
    ``0.bias`` in ``state_dict``.

    0. ``nn.Conv2d(in_channels, out_channels, kernel_size)`` with
       ``padding=kernel_size // 2``. For odd ``kernel_size`` this preserves
       the spatial size. For even ``kernel_size`` the output is one pixel
       larger in each spatial dimension.
    1. Bilinear upsampling by ``upsampling`` (``align_corners=True``) when
       ``upsampling > 1``; otherwise ``nn.Identity()``. Values ``<= 1`` are
       therefore ignored, and no downsampling is ever performed.
    2. ``Activation(activation)``.

    Args:
        in_channels: Number of input channels to the convolution.
        out_channels: Number of output channels of the convolution.
        kernel_size: Convolution kernel size.
        activation: Forwarded to ``Activation`` (``None`` -> identity).
        upsampling: Scale factor for the bilinear upsampling step.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=3,
        activation=None,
        upsampling=1
    ):
        conv2d = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            padding=kernel_size // 2
        )
        # nn.UpsamplingBilinear2d is deprecated. It is exactly nn.Upsample with
        # mode='bilinear' and align_corners=True, so outputs are unchanged
        # (and neither layer has parameters, so state_dict is unaffected).
        upsampling_layer = (
            nn.Upsample(scale_factor=upsampling, mode='bilinear', align_corners=True)
            if upsampling > 1
            else nn.Identity()
        )
        activation_layer = Activation(activation)
        super().__init__(conv2d, upsampling_layer, activation_layer)