File size: 5,839 Bytes
853e22b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import torch
from torch import nn

from monai.utils import optional_import

if optional_import("torch.nn.functional", name="mish")[1]:
    # ``torch.nn.functional.mish`` could be imported: delegate to it.

    def monai_mish(x, inplace: bool = False):
        """Apply Mish to ``x`` through ``torch.nn.functional.mish``, forwarding ``inplace``."""
        return nn.functional.mish(x, inplace=inplace)

else:
    # ``mish`` is unavailable: compose it from softplus and tanh instead.

    def monai_mish(x, inplace: bool = False):
        """Apply Mish to ``x`` as ``x * tanh(softplus(x))``.

        ``inplace`` is accepted for signature parity only; this branch never uses it.
        """
        softplus_x = nn.functional.softplus(x)
        return x * torch.tanh(softplus_x)


if optional_import("torch.nn.functional", name="silu")[1]:
    # ``torch.nn.functional.silu`` could be imported: delegate to it.

    def monai_swish(x, inplace: bool = False):
        """Apply Swish (alpha = 1) to ``x`` through ``torch.nn.functional.silu``, forwarding ``inplace``."""
        return nn.functional.silu(x, inplace=inplace)

else:
    # ``silu`` is unavailable: use the custom autograd function from this module.

    def monai_swish(x, inplace: bool = False):
        """Apply Swish (alpha = 1) to ``x`` through ``SwishImplementation``.

        ``inplace`` is accepted for signature parity only; this branch never uses it.
        """
        return SwishImplementation.apply(x)


class Swish(nn.Module):
    r"""Swish activation with a constant scaling factor :math:`\alpha`, applied element-wise:

    .. math::
        \text{Swish}(x) = x * \text{Sigmoid}(\alpha * x) ~~~~\text{for constant value}~ \alpha.

    Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.

    Args:
        alpha: constant multiplier applied to the input before the sigmoid. Defaults to 1.0.

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional dimensions
        - Output: :math:`(N, *)`, same shape as the input


    Examples::

        >>> import torch
        >>> from monai.networks.layers.factories import Act
        >>> m = Act['swish']()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def __init__(self, alpha=1.0):
        super().__init__()
        # Stored as a plain attribute, so it is a fixed constant rather than a learnable parameter.
        self.alpha = alpha

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        """Return ``input`` scaled element-wise by ``sigmoid(alpha * input)``."""
        gate = torch.sigmoid(self.alpha * input)
        return input * gate


class SwishImplementation(torch.autograd.Function):
    r"""Custom autograd function for Swish with :math:`\alpha = 1`.

    Only the input is saved for the backward pass and the sigmoid is recomputed there.
    This follows the recommendation from:
    https://github.com/lukemelas/EfficientNet-PyTorch/issues/18#issuecomment-511677853

    Results in ~ 30% memory saving during training as compared to Swish()
    """

    @staticmethod
    def forward(ctx, input):
        """Compute ``input * sigmoid(input)`` and stash ``input`` for ``backward``."""
        ctx.save_for_backward(input)
        return input * torch.sigmoid(input)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``grad_output`` times the Swish derivative ``s * (1 + x * (1 - s))``, where ``s = sigmoid(x)``."""
        (saved_input,) = ctx.saved_tensors
        sig = torch.sigmoid(saved_input)
        local_grad = sig * (1 + saved_input * (1 - sig))
        return grad_output * local_grad


class MemoryEfficientSwish(nn.Module):
    r"""Swish activation with :math:`\alpha = 1`, applied element-wise:

    .. math::
        \text{Swish}(x) = x * \text{Sigmoid}(\alpha * x) ~~~~\text{for constant value}~ \alpha=1.

    Memory efficient implementation for training following recommendation from:
    https://github.com/lukemelas/EfficientNet-PyTorch/issues/18#issuecomment-511677853

    Results in ~ 30% memory saving during training as compared to Swish()

    Citation: Searching for Activation Functions, Ramachandran et al., 2017, https://arxiv.org/abs/1710.05941.

    From PyTorch 1.7.0+, an optimized version of `Swish` named `SiLU` is available;
    this class uses `torch.nn.functional.silu` for the computation when it can be imported.

    Args:
        inplace: forwarded to `torch.nn.functional.silu`; it has no effect when that function is unavailable.

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(N, *)`, same shape as the input


    Examples::

        >>> import torch
        >>> from monai.networks.layers.factories import Act
        >>> m = Act['memswish']()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        # Only honoured by the torch.nn.functional.silu code path.
        self.inplace = inplace

    def forward(self, input: torch.Tensor):
        """Apply Swish to ``input`` through the module-level ``monai_swish`` helper."""
        return monai_swish(input, inplace=self.inplace)


class Mish(nn.Module):
    r"""Mish activation, applied element-wise:

    .. math::
        \text{Mish}(x) = x * tanh(\text{softplus}(x)).

    Citation: Mish: A Self Regularized Non-Monotonic Activation Function, Diganta Misra, 2019, https://arxiv.org/abs/1908.08681.

    From PyTorch 1.9.0+, an optimized version of `Mish` is available;
    this class uses `torch.nn.functional.mish` for the computation when it can be imported.

    Args:
        inplace: forwarded to `torch.nn.functional.mish`; it has no effect when that function is unavailable.

    Shape:
        - Input: :math:`(N, *)` where `*` means, any number of additional dimensions
        - Output: :math:`(N, *)`, same shape as the input


    Examples::

        >>> import torch
        >>> from monai.networks.layers.factories import Act
        >>> m = Act['mish']()
        >>> input = torch.randn(2)
        >>> output = m(input)
    """

    def __init__(self, inplace: bool = False):
        super().__init__()
        # Only honoured by the torch.nn.functional.mish code path.
        self.inplace = inplace

    def forward(self, input: torch.Tensor):
        """Apply Mish to ``input`` through the module-level ``monai_mish`` helper."""
        return monai_mish(input, inplace=self.inplace)


class GEGLU(nn.Module):
    r"""Applies the GELU-gated linear unit:

    .. math::
        \text{GEGLU}(x) = x_1 * \text{GELU}(x_2)

    where :math:`x_1` and :math:`x_2` are the two equal halves of the input tensor, split along the last dimension.

    Citation: GLU Variants Improve Transformer, Noam Shazeer, 2020, https://arxiv.org/abs/2002.05202.

    Shape:
        - Input: :math:`(N, *, 2 * D)`
        - Output: :math:`(N, *, D)`, where `*` means, any number of additional dimensions

    Raises:
        ValueError: When the size of the last input dimension is not even.
    """

    def forward(self, input: torch.Tensor):
        """Split ``input`` in half along the last dimension and gate the first half with GELU of the second."""
        last_dim = input.shape[-1]
        # ``chunk`` yields unequal pieces for an odd size (e.g. 3 -> 2 and 1), which could then
        # broadcast silently into a wrong result, so reject such inputs up front.
        if last_dim % 2 != 0:
            raise ValueError(f"GEGLU requires an even-sized last dimension, got {last_dim}.")
        x, gate = input.chunk(2, dim=-1)
        return x * nn.functional.gelu(gate)