# Refer to the code from PoolFormer, thanks!
# https://github.com/sail-sg/poolformer/blob/main/models/poolformer.py
import torch
import torch.nn as nn
from timm.layers import DropPath, trunc_normal_


class GroupNorm(nn.GroupNorm):
    """
    Group Normalization with 1 group.
    Input: tensor in shape [B, C, H, W]
    --num_channels: number of channels C of the input
    --kwargs: forwarded unchanged to nn.GroupNorm
    """

    def __init__(self, num_channels, **kwargs):
        # The number of groups is fixed to 1; everything else is up to the caller.
        super().__init__(1, num_channels, **kwargs)


class Pooling(nn.Module):
    """
    Implementation of pooling for PoolFormer
    --pool_size: pooling size, must be a positive odd integer
    Raises ValueError for a non-positive or even pool_size.
    """

    def __init__(self, pool_size=3):
        super().__init__()
        # With stride 1 and padding pool_size // 2, only an odd kernel keeps
        # the spatial size unchanged. An even kernel yields (H+1, W+1), which
        # makes the subtraction in forward() fail (or silently broadcast when
        # H == W == 1), so reject it up front with a clear message.
        if pool_size < 1 or pool_size % 2 == 0:
            raise ValueError(
                "pool_size must be a positive odd integer so that the pooled "
                f"map keeps the input's spatial size, got {pool_size}"
            )
        self.pool = nn.AvgPool2d(
            pool_size,
            stride=1,
            padding=pool_size // 2,
            count_include_pad=False,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the average-pooled `x` minus `x` itself (same shape as `x`)."""
        pooled = self.pool(x)
        # PoolFormerBlock wraps this module in its own skip connection, so
        # only the difference to the local average is returned here.
        return pooled - x


class Mlp(nn.Module):
    """
    Implementation of MLP with 1*1 convolutions.
    Input: tensor with shape [B, C, H, W]
    --in_features: number of input channels
    --hidden_features: hidden channels, falls back to in_features when unset
    --out_features: output channels, falls back to in_features when unset
    --act_layer: activation class, instantiated without arguments
    --drop: dropout rate
    """

    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)
        self.drop = nn.Dropout(drop)

        # Re-initialise every Conv2d created above.
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal weights (std 0.02) and zero bias for Conv2d modules."""
        if not isinstance(m, nn.Conv2d):
            return
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply fc1 -> activation -> dropout -> fc2 -> dropout."""
        hidden = self.drop(self.act(self.fc1(x)))
        # The same dropout module is reused after the second projection.
        return self.drop(self.fc2(hidden))


class PoolFormerBlock(nn.Module):
    """
    Implementation of one PoolFormer block.
    --dim: embedding dim
    --pool_size: pooling size
    --mlp_ratio: mlp expansion ratio
    --act_layer: activation
    --norm_layer: normalization
    --drop: dropout rate
    --drop path: Stochastic Depth,
        refer to https://arxiv.org/abs/1603.09382
    --init_value: LayerScale,
        refer to https://arxiv.org/abs/2103.17239
    """

    def __init__(self, dim, pool_size=3, mlp_ratio=4.,
                 drop=0., drop_path=0., init_value=1e-5,
                 act_layer=nn.GELU, norm_layer=GroupNorm):
        super().__init__()

        # Token-mixing branch.
        self.norm1 = norm_layer(dim)
        self.token_mixer = Pooling(pool_size=pool_size)

        # Channel-MLP branch.
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)

        # The following two techniques are useful to train deep PoolFormers.
        self.drop_path = DropPath(drop_path) if drop_path > 0. \
            else nn.Identity()
        # One learnable scale per channel for each branch, starting at init_value.
        self.layer_scale_1 = nn.Parameter(
            init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            init_value * torch.ones((dim)), requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both residual branches (token mixer, then MLP) on `x`."""
        # Reshape the (dim,) scales to (dim, 1, 1) so they broadcast over the
        # two trailing spatial dimensions.
        scale_1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
        scale_2 = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)

        x = x + self.drop_path(scale_1 * self.token_mixer(self.norm1(x)))
        x = x + self.drop_path(scale_2 * self.mlp(self.norm2(x)))
        return x