| """ |
| Applies the mish function element-wise: |
| mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))) |
| """ |
|
|
| |
| import torch |
| import torch.nn.functional as F |
| from torch import nn |
|
|
@torch.jit.script
def mish(input: torch.Tensor) -> torch.Tensor:
    """
    Element-wise Mish activation, compiled with TorchScript.

    Computes mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x))).
    See the ``Mish`` module for shape details.
    """
    smooth = F.softplus(input)
    return torch.tanh(smooth) * input
|
|
class Mish(nn.Module):
    """
    Applies the Mish activation function element-wise:
    mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + exp(x)))

    Uses the native ``torch.nn.functional.mish`` when the installed PyTorch
    provides it (added in PyTorch 1.9); otherwise falls back to the
    TorchScript ``mish`` function defined in this module.

    Shape:
        - Input: (N, *) where * means any number of additional dimensions
        - Output: (N, *), same shape as the input

    Examples:
        >>> m = Mish()
        >>> input = torch.randn(2)
        >>> output = m(input)

    References:
        - https://pytorch.org/docs/stable/generated/torch.nn.Mish.html
        - D. Misra, "Mish: A Self Regularized Non-Monotonic Neural
          Activation Function", arXiv:1908.08681 (2019)
    """

    # Feature-detect the native op once, when the class is defined. A plain
    # string comparison such as ``torch.__version__ >= "1.9"`` is wrong for
    # 1.10-1.13 ("1.10" sorts before "1.9" lexicographically), and it was
    # also being re-evaluated on every forward call.
    _HAS_NATIVE_MISH = hasattr(F, "mish")

    def __init__(self):
        """Initialize the module; Mish has no learnable parameters."""
        super().__init__()

    def forward(self, input):
        """
        Apply Mish element-wise to ``input``.

        Args:
            input: Tensor of any shape.

        Returns:
            Tensor of the same shape as ``input``.
        """
        if self._HAS_NATIVE_MISH:
            return F.mish(input)
        # Older PyTorch without F.mish: use the module-level scripted helper.
        return mish(input)