| |
| |
|
|
| import torch |
| import torch.nn as nn |
| from timm.layers import DropPath, trunc_normal_ |
|
|
|
|
class GroupNorm(nn.GroupNorm):
    """
    Group Normalization restricted to a single group.

    With one group this normalizes jointly over C, H, W per sample.
    Input: tensor in shape [B, C, H, W]
    """
    def __init__(self, num_channels, **kwargs):
        super().__init__(num_groups=1, num_channels=num_channels, **kwargs)
|
|
|
|
class Pooling(nn.Module):
    """
    PoolFormer token mixer: local average pooling with the identity
    subtracted, so the module returns only the "mixing" residual.
    --pool_size: pooling window size (stride 1, same-size output)
    """
    def __init__(self, pool_size=3):
        super().__init__()
        padding = pool_size // 2
        # count_include_pad=False keeps border averages over valid pixels only.
        self.pool = nn.AvgPool2d(kernel_size=pool_size, stride=1,
                                 padding=padding, count_include_pad=False)

    def forward(self, x):
        pooled = self.pool(x)
        return pooled - x
|
|
|
|
class Mlp(nn.Module):
    """
    Channel MLP built from two 1x1 convolutions.
    Operates on tensors shaped [B, C, H, W].
    --in_features: input channel count
    --hidden_features: hidden channel count (defaults to in_features)
    --out_features: output channel count (defaults to in_features)
    --act_layer: activation module class
    --drop: dropout probability applied after each projection
    """
    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features if hidden_features else in_features
        out_features = out_features if out_features else in_features
        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1)
        self.act = act_layer()
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)
        self.drop = nn.Dropout(p=drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        # Truncated-normal weights, zero bias for every conv in the module.
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        out = self.drop(self.act(self.fc1(x)))
        out = self.drop(self.fc2(out))
        return out
|
|
|
|
class PoolFormerBlock(nn.Module):
    """
    One PoolFormer block: pooling token mixer followed by a channel MLP,
    each sub-block wrapped with pre-norm, LayerScale and stochastic depth.
    --dim: embedding dim
    --pool_size: pooling size
    --mlp_ratio: mlp expansion ratio
    --act_layer: activation
    --norm_layer: normalization
    --drop: dropout rate
    --drop_path: Stochastic Depth rate,
        refer to https://arxiv.org/abs/1603.09382
    --init_value: LayerScale init value,
        refer to https://arxiv.org/abs/2103.17239
    """
    def __init__(self, dim, pool_size=3, mlp_ratio=4., drop=0., drop_path=0.,
                 init_value=1e-5, act_layer=nn.GELU, norm_layer=GroupNorm):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.token_mixer = Pooling(pool_size=pool_size)
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(mlp_ratio * dim),
                       act_layer=act_layer, drop=drop)
        # Stochastic depth degenerates to identity at rate zero.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        # Per-channel residual scaling (LayerScale), learnable.
        self.layer_scale_1 = nn.Parameter(init_value * torch.ones(dim))
        self.layer_scale_2 = nn.Parameter(init_value * torch.ones(dim))

    def forward(self, x):
        # Broadcast the per-channel scales over the spatial dims.
        scale1 = self.layer_scale_1.unsqueeze(-1).unsqueeze(-1)
        x = x + self.drop_path(scale1 * self.token_mixer(self.norm1(x)))
        scale2 = self.layer_scale_2.unsqueeze(-1).unsqueeze(-1)
        x = x + self.drop_path(scale2 * self.mlp(self.norm2(x)))
        return x
|
|