|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from timm.layers import DropPath, trunc_normal_ |
|
|
|
|
|
|
|
|
class GroupNorm(nn.GroupNorm):
    """
    Group Normalization that always uses a single group.

    With ``num_groups`` fixed to 1, normalization statistics are computed
    jointly over all channels (and spatial positions) of each sample.
    Input: tensor in shape [B, C, H, W]

    --num_channels: number of channels C of the input
    --kwargs: forwarded unchanged to nn.GroupNorm (e.g. eps, affine)
    """

    def __init__(self, num_channels, **kwargs):
        super().__init__(num_groups=1, num_channels=num_channels, **kwargs)
|
|
|
|
|
|
|
|
class Pooling(nn.Module):
    """
    Implementation of pooling for PoolFormer.

    Token mixer that returns ``AvgPool(x) - x``. The average pool uses
    stride 1, "same" padding and ``count_include_pad=False``, so border
    positions average only over real (non-padded) neighbours.

    --pool_size: pooling size; must be a positive odd integer so that the
        pooled output keeps the input's spatial size

    Raises:
        ValueError: if pool_size is not a positive odd integer.
    """

    def __init__(self, pool_size=3):
        super().__init__()
        # With stride 1 and padding pool_size // 2, the output size is
        # H + 2 * (pool_size // 2) - pool_size + 1, which equals H only for odd
        # kernel sizes. An even size would make forward() fail later with an
        # opaque shape-mismatch error, so reject it up front.
        if pool_size < 1 or pool_size % 2 == 0:
            raise ValueError(
                f"pool_size must be a positive odd integer, got {pool_size}")
        self.pool = nn.AvgPool2d(
            pool_size, stride=1, padding=pool_size // 2, count_include_pad=False)

    def forward(self, x):
        # Subtract the input: the enclosing PoolFormerBlock adds x back through
        # its residual connection, so only the pooled "difference" is mixed in.
        return self.pool(x) - x
|
|
|
|
|
|
|
|
class Mlp(nn.Module):
    """
    Channel MLP built from two 1x1 convolutions.

    Pipeline: fc1 -> activation -> dropout -> fc2 -> dropout. Because the
    convolutions are 1x1, channels are mixed independently at every position.
    Input: tensor with shape [B, C, H, W]

    --in_features: number of input channels
    --hidden_features: hidden width; falls back to in_features when falsy
    --out_features: output channels; falls back to in_features when falsy
    --act_layer: activation class, instantiated with no arguments
    --drop: dropout probability, shared by both dropout applications
    """

    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        width_out = out_features if out_features else in_features
        width_hidden = hidden_features if hidden_features else in_features

        # Module creation order is kept fixed (fc1, act, fc2, drop) so that
        # random initialisation and state_dict layout stay stable.
        self.fc1 = nn.Conv2d(in_features, width_hidden, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(width_hidden, width_out, 1)
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Re-initialise 1x1 convolutions: truncated-normal weights (std 0.02)
        # and zero bias. Every other module type is left untouched.
        if not isinstance(m, nn.Conv2d):
            return
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x):
        for stage in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = stage(x)
        return x
|
|
|
|
|
|
|
|
class PoolFormerBlock(nn.Module):
    """
    One PoolFormer block: two pre-norm residual branches, each scaled by a
    learnable per-channel LayerScale vector.

        x = x + DropPath(layer_scale_1 * Pooling(norm1(x)))
        x = x + DropPath(layer_scale_2 * Mlp(norm2(x)))

    --dim: embedding dim (number of channels)
    --pool_size: pooling size for the Pooling token mixer
    --mlp_ratio: mlp expansion ratio; hidden width is int(dim * mlp_ratio)
    --act_layer: activation used inside the MLP
    --norm_layer: normalization class, called as norm_layer(dim)
    --drop: dropout rate inside the MLP
    --drop_path: Stochastic Depth rate; 0 disables it (identity),
        refer to https://arxiv.org/abs/1603.09382
    --init_value: initial value of the LayerScale parameters,
        refer to https://arxiv.org/abs/2103.17239
    """

    def __init__(self, dim, pool_size=3, mlp_ratio=4., drop=0., drop_path=0.,
                 init_value=1e-5, act_layer=nn.GELU, norm_layer=GroupNorm):
        super().__init__()

        # Branch 1: normalization followed by pooling-based token mixing.
        self.norm1 = norm_layer(dim)
        self.token_mixer = Pooling(pool_size=pool_size)

        # Branch 2: normalization followed by the channel MLP.
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)

        # A single drop-path module is shared by both residual branches.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()

        # Learnable per-channel scales, shape (dim,), trainable by default.
        self.layer_scale_1 = nn.Parameter(init_value * torch.ones(dim))
        self.layer_scale_2 = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x):
        # Reshape the (C,) scales to (C, 1, 1) so they broadcast over H and W.
        gamma_mix = self.layer_scale_1[:, None, None]
        gamma_mlp = self.layer_scale_2[:, None, None]

        x = x + self.drop_path(gamma_mix * self.token_mixer(self.norm1(x)))
        x = x + self.drop_path(gamma_mlp * self.mlp(self.norm2(x)))
        return x
|
|
|