| | """ MLP module w/ dropout and configurable activation layer |
| | |
| | Hacked together by / Copyright 2020 Ross Wightman |
| | """ |
| | from torch import nn as nn |
| |
|
| |
|
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
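
# Example usage (a minimal sketch; the 768/3072 sizes are illustrative ViT-Base
# style dims, applied over the last dim of a (batch, tokens, channels) tensor):
#
#   import torch
#   mlp = Mlp(in_features=768, hidden_features=3072, drop=0.1)
#   out = mlp(torch.randn(2, 197, 768))  # -> torch.Size([2, 197, 768])

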
class GluMlp(nn.Module):
    """ MLP w/ GLU style gating
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.Sigmoid, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        assert hidden_features % 2 == 0
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features // 2, out_features)
        self.drop = nn.Dropout(drop)

    def init_weights(self):
        # override default init for the gate half of fc1: weights near zero, bias of 1
        fc1_mid = self.fc1.bias.shape[0] // 2
        nn.init.ones_(self.fc1.bias[fc1_mid:])
        nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x, gates = x.chunk(2, dim=-1)
        x = x * self.act(gates)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
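
# Example usage (a minimal sketch; sizes are illustrative, and hidden_features
# must be even so fc1's output can be split into value/gate halves):
#
#   import torch
#   glu = GluMlp(in_features=256, hidden_features=512)
#   glu.init_weights()  # re-initializes the gate half (near-zero weight, bias=1)
#   out = glu(torch.randn(8, 196, 256))  # -> torch.Size([8, 196, 256])

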
class GatedMlp(nn.Module):
    """ MLP as used in gMLP
    """
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU,
                 gate_layer=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        if gate_layer is not None:
            assert hidden_features % 2 == 0
            self.gate = gate_layer(hidden_features)
            hidden_features = hidden_features // 2  # fc2 input is halved, assuming the gate splits channels in two
        else:
            self.gate = nn.Identity()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.gate(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
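
# Example usage (a minimal sketch): with gate_layer=None this reduces to a plain
# Mlp. HalvingGate below is only an illustrative stand-in for a gMLP-style
# spatial gating unit, i.e. any module that returns half the channels it receives.
#
#   import torch
#
#   class HalvingGate(nn.Module):
#       def __init__(self, dim):
#           super().__init__()
#       def forward(self, x):
#           u, v = x.chunk(2, dim=-1)
#           return u * v
#
#   gated = GatedMlp(in_features=256, hidden_features=1024, gate_layer=HalvingGate)
#   out = gated(torch.randn(2, 196, 256))  # -> torch.Size([2, 196, 256])

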
class ConvMlp(nn.Module):
    """ MLP using 1x1 convs that keeps spatial dims
    """
    def __init__(
            self, in_features, hidden_features=None, out_features=None, act_layer=nn.ReLU, norm_layer=None, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=True)
        self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity()
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=True)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.norm(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        return x
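
# Example usage (a minimal sketch; channel counts and the BatchNorm2d choice are
# illustrative). Operates on NCHW feature maps and keeps the spatial dims:
#
#   import torch
#   cmlp = ConvMlp(in_features=64, hidden_features=256, norm_layer=nn.BatchNorm2d)
#   out = cmlp(torch.randn(2, 64, 14, 14))  # -> torch.Size([2, 64, 14, 14])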