| """ MLP module w/ dropout and configurable activation layer |
| |
| Hacked together by / Copyright 2020 Ross Wightman |
| """ |
| from functools import partial |
|
|
| from torch import nn as nn |
|
|
| from .grn import GlobalResponseNorm |
| from .helpers import to_2tuple |
|
|
|
|
class Mlp(nn.Module):
    """ MLP as used in Vision Transformer, MLP-Mixer and related networks

    Pipeline: fc1 -> act -> drop1 -> norm -> fc2 -> drop2.

    NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.

    Args:
        in_features: Input channel/feature count.
        hidden_features: Hidden width; falls back to `in_features` when falsy.
        out_features: Output channel/feature count; falls back to `in_features` when falsy.
        act_layer: Activation class, instantiated with no arguments.
        norm_layer: Optional norm class built with the hidden width; Identity when None.
        bias: Passed through `to_2tuple`; element 0 is used for fc1, element 1 for fc2.
        drop: Passed through `to_2tuple`; element 0 is the hidden dropout, element 1 the output dropout.
        use_conv: Use 1x1 `nn.Conv2d` projections instead of `nn.Linear`.
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        biases = to_2tuple(bias)
        drops = to_2tuple(drop)
        if use_conv:
            make_proj = partial(nn.Conv2d, kernel_size=1)
        else:
            make_proj = nn.Linear

        # Construction/registration order is significant: it fixes RNG consumption
        # for weight init and the order of parameters / state_dict entries.
        self.fc1 = make_proj(in_features, hidden_features, bias=biases[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drops[0])
        self.norm = nn.Identity() if norm_layer is None else norm_layer(hidden_features)
        self.fc2 = make_proj(hidden_features, out_features, bias=biases[1])
        self.drop2 = nn.Dropout(drops[1])

    def forward(self, x):
        hidden = self.drop1(self.act(self.fc1(x)))
        hidden = self.norm(hidden)
        return self.drop2(self.fc2(hidden))
|
|
|
|
class GluMlp(nn.Module):
    """ MLP w/ GLU style gating
    See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202

    fc1 projects to `hidden_features` channels. The result is split into two equal
    halves along `chunk_dim`. One half goes through `act_layer` and multiplies the
    other half, so drop1 / norm / fc2 operate on `hidden_features // 2` channels.

    NOTE: When use_conv=True, expects 2D NCHW tensors, otherwise N*C expected.

    Args:
        in_features: Input channel/feature count.
        hidden_features: fc1 output width (must be even); falls back to `in_features` when falsy.
        out_features: Output channel/feature count; falls back to `in_features` when falsy.
        act_layer: Gate activation class, instantiated with no arguments.
        norm_layer: Optional norm class built with `hidden_features // 2`; Identity when None.
        bias: Passed through `to_2tuple`; element 0 is used for fc1, element 1 for fc2.
        drop: Passed through `to_2tuple`; element 0 is the hidden dropout, element 1 the output dropout.
        use_conv: Use 1x1 `nn.Conv2d` projections and split along dim 1 instead of dim -1.
        gate_last: If True the second half is the gate (`x1 * act(x2)`),
            otherwise the first half is (`act(x1) * x2`).
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.Sigmoid,
            norm_layer=None,
            bias=True,
            drop=0.,
            use_conv=False,
            gate_last=True,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        assert hidden_features % 2 == 0, \
            f'hidden_features ({hidden_features}) must be even to split into value/gate halves'
        bias = to_2tuple(bias)
        drop_probs = to_2tuple(drop)
        linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
        # Channel dim is 1 for NCHW conv inputs, last dim for N*C linear inputs.
        self.chunk_dim = 1 if use_conv else -1
        self.gate_last = gate_last

        self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.norm = norm_layer(hidden_features // 2) if norm_layer is not None else nn.Identity()
        self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1])
        self.drop2 = nn.Dropout(drop_probs[1])

    def init_weights(self):
        """ Override init of the gate portion of fc1: weights ~N(0, 1e-6), bias = 1.

        The gate half is the one `forward` passes through `act`. That is the second
        half of fc1's output channels when `gate_last` is True and the first half
        otherwise. The bias is only touched when fc1 has one.
        """
        mid = self.fc1.weight.shape[0] // 2
        # Must select the same chunk that forward() treats as the gate.
        gate = slice(mid, None) if self.gate_last else slice(0, mid)
        if self.fc1.bias is not None:
            nn.init.ones_(self.fc1.bias[gate])
        nn.init.normal_(self.fc1.weight[gate], std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x1, x2 = x.chunk(2, dim=self.chunk_dim)
        x = x1 * self.act(x2) if self.gate_last else self.act(x1) * x2
        x = self.drop1(x)
        x = self.norm(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x


# Packed-fc1 SwiGLU: SiLU-gated GLU with the gate in the first half of fc1's output.
SwiGLUPacked = partial(GluMlp, act_layer=nn.SiLU, gate_last=False)
|
|
|
|
class SwiGLU(nn.Module):
    """ SwiGLU
    NOTE: GluMlp can implement SwiGLU, but this impl has split fc1 (separate gate
    and value projections) and better matches some other common impl, which makes
    mapping checkpoints simpler.

    Computes fc2(norm(act(fc1_g(x)) * fc1_x(x))) with dropout after the product
    (drop1) and after fc2 (drop2).

    Args:
        in_features: Input feature count.
        hidden_features: Width of both fc1 projections; falls back to `in_features` when falsy.
        out_features: Output feature count; falls back to `in_features` when falsy.
        act_layer: Gate activation class, instantiated with no arguments.
        norm_layer: Optional norm class built with the hidden width; Identity when None.
        bias: Passed through `to_2tuple`; element 0 for both fc1 projections, element 1 for fc2.
        drop: Passed through `to_2tuple`; element 0 is the hidden dropout, element 1 the output dropout.
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.SiLU,
            norm_layer=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        biases = to_2tuple(bias)
        drops = to_2tuple(drop)
        in_bias, out_bias = biases[0], biases[1]

        # Keep construction order: it determines RNG use during init and parameter order.
        self.fc1_g = nn.Linear(in_features, hidden_features, bias=in_bias)
        self.fc1_x = nn.Linear(in_features, hidden_features, bias=in_bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drops[0])
        self.norm = nn.Identity() if norm_layer is None else norm_layer(hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=out_bias)
        self.drop2 = nn.Dropout(drops[1])

    def init_weights(self):
        # Gate projection starts with near-zero weights and (if present) unit bias.
        gate_fc = self.fc1_g
        if gate_fc.bias is not None:
            nn.init.ones_(gate_fc.bias)
        nn.init.normal_(gate_fc.weight, std=1e-6)

    def forward(self, x):
        gate_pre = self.fc1_g(x)
        value = self.fc1_x(x)
        hidden = self.act(gate_pre) * value
        hidden = self.norm(self.drop1(hidden))
        return self.drop2(self.fc2(hidden))
|
|
|
|
class GatedMlp(nn.Module):
    """ MLP as used in gMLP

    Pipeline: fc1 -> act -> drop1 -> gate -> norm -> fc2 -> drop2.

    When `gate_layer` is given, it is built with `hidden_features`, while norm and fc2
    are sized for `hidden_features // 2`. The gate layer must therefore halve the
    channel count. Without a gate layer, an Identity is used and no halving occurs.

    Args:
        in_features: Input feature count.
        hidden_features: fc1 output width; falls back to `in_features` when falsy.
        out_features: Output feature count; falls back to `in_features` when falsy.
        act_layer: Activation class, instantiated with no arguments.
        norm_layer: Optional norm class built with the post-gate width; Identity when None.
        gate_layer: Optional gating module class, called with `hidden_features`.
        bias: Passed through `to_2tuple`; element 0 for fc1, element 1 for fc2.
        drop: Passed through `to_2tuple`; element 0 is the hidden dropout, element 1 the output dropout.
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            norm_layer=None,
            gate_layer=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        biases = to_2tuple(bias)
        drops = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=biases[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drops[0])
        if gate_layer is None:
            self.gate = nn.Identity()
            gated_features = hidden_features
        else:
            assert hidden_features % 2 == 0
            self.gate = gate_layer(hidden_features)
            gated_features = hidden_features // 2
        self.norm = nn.Identity() if norm_layer is None else norm_layer(gated_features)
        self.fc2 = nn.Linear(gated_features, out_features, bias=biases[1])
        self.drop2 = nn.Dropout(drops[1])

    def forward(self, x):
        y = self.drop1(self.act(self.fc1(x)))
        y = self.norm(self.gate(y))
        return self.drop2(self.fc2(y))
|
|
|
|
class ConvMlp(nn.Module):
    """ MLP using 1x1 convs that keeps spatial dims (for 2D NCHW tensors)

    Pipeline: fc1 -> norm -> act -> drop -> fc2. Unlike the other MLPs here, the norm
    comes before the activation, and there is a single dropout with none after fc2.

    Args:
        in_features: Input channel count.
        hidden_features: Hidden channel count; falls back to `in_features` when falsy.
        out_features: Output channel count; falls back to `in_features` when falsy.
        act_layer: Activation class, instantiated with no arguments.
        norm_layer: Optional norm class built with the hidden width; Identity when falsy.
        bias: Passed through `to_2tuple`; element 0 for fc1, element 1 for fc2.
        drop: Dropout probability, passed directly to `nn.Dropout`.
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.ReLU,
            norm_layer=None,
            bias=True,
            drop=0.,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        biases = to_2tuple(bias)

        # Order of construction preserved (affects RNG-based init and parameter order).
        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=biases[0])
        self.norm = nn.Identity() if not norm_layer else norm_layer(hidden_features)
        self.act = act_layer()
        self.drop = nn.Dropout(drop)
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=biases[1])

    def forward(self, x):
        y = self.fc1(x)
        y = self.act(self.norm(y))
        y = self.drop(y)
        return self.fc2(y)
|
|
|
|
class GlobalResponseNormMlp(nn.Module):
    """ MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d

    Pipeline: fc1 -> act -> drop1 -> grn -> fc2 -> drop2. The GRN layer is told
    `channels_last=not use_conv`.

    NOTE: Intended for '2D' NCHW (use_conv=True) or NHWC (use_conv=False, channels-last) tensor layouts

    Args:
        in_features: Input channel count.
        hidden_features: Hidden channel count; falls back to `in_features` when falsy.
        out_features: Output channel count; falls back to `in_features` when falsy.
        act_layer: Activation class, instantiated with no arguments.
        bias: Passed through `to_2tuple`; element 0 for fc1, element 1 for fc2.
        drop: Passed through `to_2tuple`; element 0 is the hidden dropout, element 1 the output dropout.
        use_conv: Use 1x1 `nn.Conv2d` projections instead of `nn.Linear`.
    """
    def __init__(
            self,
            in_features,
            hidden_features=None,
            out_features=None,
            act_layer=nn.GELU,
            bias=True,
            drop=0.,
            use_conv=False,
    ):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        biases = to_2tuple(bias)
        drops = to_2tuple(drop)
        make_proj = nn.Linear if not use_conv else partial(nn.Conv2d, kernel_size=1)
        channels_last = not use_conv

        self.fc1 = make_proj(in_features, hidden_features, bias=biases[0])
        self.act = act_layer()
        self.drop1 = nn.Dropout(drops[0])
        self.grn = GlobalResponseNorm(hidden_features, channels_last=channels_last)
        self.fc2 = make_proj(hidden_features, out_features, bias=biases[1])
        self.drop2 = nn.Dropout(drops[1])

    def forward(self, x):
        hidden = self.drop1(self.act(self.fc1(x)))
        hidden = self.grn(hidden)
        return self.drop2(self.fc2(hidden))
|
|