from functools import partial import torch import torch.nn as nn import torch.nn.functional as F from timm.models.layers import trunc_normal_, DropPath from timm.models.registry import register_model from timm.models.layers.helpers import to_2tuple class LayerNormGeneral(nn.Module): r""" General LayerNorm for different situations. Args: affine_shape (int, list or tuple): The shape of affine weight and bias. Usually the affine_shape=C, but in some implementation, like torch.nn.LayerNorm, the affine_shape is the same as normalized_dim by default. To adapt to different situations, we offer this argument here. normalized_dim (tuple or list): Which dims to compute mean and variance. scale (bool): Flag indicates whether to use scale or not. bias (bool): Flag indicates whether to use scale or not. We give several examples to show how to specify the arguments. LayerNorm (https://arxiv.org/abs/1607.06450): For input shape of (B, *, C) like (B, N, C) or (B, H, W, C), affine_shape=C, normalized_dim=(-1, ), scale=True, bias=True; For input shape of (B, C, H, W), affine_shape=(C, 1, 1), normalized_dim=(1, ), scale=True, bias=True. Modified LayerNorm (https://arxiv.org/abs/2111.11418) that is idental to partial(torch.nn.GroupNorm, num_groups=1): For input shape of (B, N, C), affine_shape=C, normalized_dim=(1, 2), scale=True, bias=True; For input shape of (B, H, W, C), affine_shape=C, normalized_dim=(1, 2, 3), scale=True, bias=True; For input shape of (B, C, H, W), affine_shape=(C, 1, 1), normalized_dim=(1, 2, 3), scale=True, bias=True. For the several metaformer baslines, IdentityFormer, RandFormer and PoolFormerV2 utilize Modified LayerNorm without bias (bias=False); ConvFormer and CAFormer utilizes LayerNorm without bias (bias=False). """ def __init__(self, affine_shape=None, normalized_dim=(-1, ), scale=True, bias=False, eps=1e-6): super().__init__() self.normalized_dim = normalized_dim self.use_scale = scale self.use_bias = bias self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None self.eps = eps def forward(self, x): c = x - x.mean(self.normalized_dim, keepdim=True) s = c.pow(2).mean(self.normalized_dim, keepdim=True) x = c / torch.sqrt(s + self.eps) if self.use_scale: x = x * self.weight if self.use_bias: x = x + self.bias return x def stem(in_chs, out_chs, act_layer=nn.GELU): return nn.Sequential( nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1), ## nn.BatchNorm2d(out_chs // 2), nn.InstanceNorm2d(out_chs // 2), act_layer(), nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1), ## nn.BatchNorm2d(out_chs), nn.InstanceNorm2d(out_chs), act_layer(), ) class Downsampling(nn.Module): """ Downsampling implemented by a layer of convolution. """ def __init__(self, in_channels, out_channels, kernel_size=3, stride=2, padding=1, pre_norm=LayerNormGeneral, post_norm=None, pre_permute=True): super().__init__() self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity() self.pre_permute = pre_permute self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding) self.post_norm = post_norm( out_channels) if post_norm else nn.Identity() def forward(self, x): x = self.pre_norm(x) if self.pre_permute: x = x.permute(0, 3, 1, 2).contiguous() # if take [B, H, W, C] as input, permute it to [B, C, H, W] x = self.conv(x) x = x.permute(0, 2, 3, 1).contiguous() # [B, C, H, W] -> [B, H, W, C] x = self.post_norm(x) return x class Scale(nn.Module): """ Scale vector by element multiplications. """ def __init__(self, dim, init_value=1.0, trainable=True): super().__init__() self.scale = nn.Parameter( init_value * torch.ones(dim), requires_grad=trainable) def forward(self, x): return x * self.scale class LayerNormWithoutBias(nn.Module): """ Equal to partial(LayerNormGeneral, bias=False) but faster, because it directly utilizes otpimized F.layer_norm """ def __init__(self, normalized_shape, eps=1e-5, **kwargs): super().__init__() self.eps = eps self.bias = None if isinstance(normalized_shape, int): normalized_shape = (normalized_shape,) self.weight = nn.Parameter(torch.ones(normalized_shape)) self.normalized_shape = normalized_shape def forward(self, x): return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps) class SepConv(nn.Module): r""" Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381. """ def __init__(self, dim, expansion_ratio=2, act1_layer=nn.GELU, act2_layer=nn.Identity, bias=False, kernel_size=3, padding=1, **kwargs, ): super().__init__() med_channels = int(expansion_ratio * dim) self.pwconv1 = nn.Linear(dim, med_channels, bias=bias) self.act1 = act1_layer() self.dwconv = nn.Conv2d( med_channels, med_channels, kernel_size=kernel_size, padding=padding, groups=med_channels, bias=bias) # depthwise conv self.act2 = act2_layer() self.pwconv2 = nn.Linear(med_channels, dim, bias=bias) def forward(self, x): x = self.pwconv1(x) x = self.act1(x) x = x.permute(0, 3, 1, 2) x = self.dwconv(x) x = x.permute(0, 2, 3, 1) x = self.act2(x) x = self.pwconv2(x) return x class Mlp(nn.Module): """ MLP as used in MetaFormer models, eg Transformer, MLP-Mixer, PoolFormer, MetaFormer baslines and related networks. Mostly copied from timm. """ def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=nn.GELU, drop=0., bias=False, **kwargs): super().__init__() in_features = dim out_features = out_features or in_features hidden_features = int(mlp_ratio * in_features) drop_probs = to_2tuple(drop) self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) self.act = act_layer() self.drop1 = nn.Dropout(drop_probs[0]) self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) self.drop2 = nn.Dropout(drop_probs[1]) def forward(self, x): x = self.fc1(x) x = self.act(x) x = self.drop1(x) x = self.fc2(x) x = self.drop2(x) return x class MetaFormerBlock(nn.Module): """ Implementation of one MetaFormer block. """ def __init__(self, dim, token_mixer=nn.Identity, mlp=Mlp, mlp_ratio=4, norm_layer=nn.LayerNorm, drop=0., drop_path=0., layer_scale_init_value=None, res_scale_init_value=None ): super().__init__() self.token_mixer = token_mixer(dim, drop=drop) self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.norm1 = norm_layer(dim) self.layer_scale1 = Scale(dim=dim, init_value=layer_scale_init_value) \ if layer_scale_init_value else nn.Identity() self.res_scale1 = Scale(dim=dim, init_value=res_scale_init_value) \ if res_scale_init_value else nn.Identity() self.norm2 = norm_layer(dim) self.mlp = mlp(dim=dim, mlp_ratio=mlp_ratio, drop=drop) self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity() self.layer_scale2 = Scale(dim=dim, init_value=layer_scale_init_value) \ if layer_scale_init_value else nn.Identity() self.res_scale2 = Scale(dim=dim, init_value=res_scale_init_value) \ if res_scale_init_value else nn.Identity() def forward(self, x): x = x + self.drop_path1(self.token_mixer(self.norm1(x))) x = x + self.drop_path2(self.mlp(self.norm2(x))) return x class MetaFormer(nn.Module): r""" MetaFormer A PyTorch impl of : `MetaFormer Baselines for Vision` - https://arxiv.org/abs/2210.13452 Args: in_chans (int): Number of input image channels. Default: 3. num_classes (int): Number of classes for classification head. Default: 1000. depths (list or tuple): Number of blocks at each stage. Default: [2, 2, 6, 2]. dims (int): Feature dimension at each stage. Default: [64, 128, 320, 512]. downsample_layers: (list or tuple): Downsampling layers before each stage. token_mixers (list, tuple or token_fcn): Token mixer for each stage. Default: nn.Identity. mlps (list, tuple or mlp_fcn): Mlp for each stage. Default: Mlp. norm_layers (list, tuple or norm_fcn): Norm layers for each stage. Default: partial(LayerNormGeneral, eps=1e-6, bias=False). drop_path_rate (float): Stochastic depth rate. Default: 0. layer_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: None. None means not use the layer scale. Form: https://arxiv.org/abs/2103.17239. res_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: [None, None, 1.0, 1.0]. None means not use the layer scale. From: https://arxiv.org/abs/2110.09456. head_fn: classification head. Default: nn.Linear. """ def __init__(self, in_chans=3, num_classes=1000, depths=[2, 2, 6, 2], dims=[64, 128, 320, 512], downsample_layers=[stem] + [Downsampling]*3, token_mixers=nn.Identity, mlps=Mlp, mlp_ratio=4, norm_layers=partial(LayerNormWithoutBias, eps=1e-6), drop_path_rate=0., layer_scale_init_values=None, res_scale_init_values=[None, None, 1.0, 1.0], head_fn=nn.Linear, **kwargs, ): super().__init__() self.num_classes = num_classes if not isinstance(depths, (list, tuple)): depths = [depths] # it means the model has only one stage if not isinstance(dims, (list, tuple)): dims = [dims] self.dims = dims self.depths = depths num_stage = len(depths) self.num_stage = num_stage down_dims = [in_chans] + dims self.downsample_layers = nn.ModuleList([downsample_layers[i](down_dims[i], down_dims[i+1]) for i in range(num_stage)]) if not isinstance(token_mixers, (list, tuple)): token_mixers = [token_mixers] * num_stage self.token_mixers = token_mixers if not isinstance(mlps, (list, tuple)): mlps = [mlps] * num_stage if not isinstance(norm_layers, (list, tuple)): norm_layers = [norm_layers] * num_stage dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] if not isinstance(layer_scale_init_values, (list, tuple)): layer_scale_init_values = [layer_scale_init_values] * num_stage if not isinstance(res_scale_init_values, (list, tuple)): res_scale_init_values = [res_scale_init_values] * num_stage self.stages = nn.ModuleList() # each stage consists of multiple metaformer blocks cur = 0 for i in range(num_stage): stage = nn.ModuleList( [MetaFormerBlock(dim=dims[i], token_mixer=token_mixers[i], mlp=mlps[i], mlp_ratio=mlp_ratio, norm_layer=norm_layers[i], drop_path=dp_rates[cur + j], layer_scale_init_value=layer_scale_init_values[i], res_scale_init_value=res_scale_init_values[i], ) for j in range(depths[i])] ) self.stages.append(stage) cur += depths[i] self.head = head_fn(dims[-1], num_classes) self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, (nn.Conv2d, nn.Linear)): trunc_normal_(m.weight, std=.02) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, x): outs = [] for i in range(self.num_stage): x = self.downsample_layers[i](x) if i==0: x = x.permute(0, 2, 3, 1).contiguous() # [B, C, H, W] -> [B, H, W, C] for j in range(self.depths[i]): x= self.stages[i][j](x) outs.append(x) # [B, H, W, C] return outs def convformer(variant='tiny'): if variant == 'tiny': model = convformer_t() elif variant == 'small': model = convformer_s() elif variant == 'base': model = convformer_b() elif variant == 'large': model = convformer_l() else: raise NotImplementedError return model @register_model def convformer_t(**kwargs): model = MetaFormer( depths=[2, 2, 6, 2], dims=[32, 64, 128, 160], mlps=Mlp, mlp_ratio=2, token_mixers=[SepConv, SepConv, SepConv, SepConv], head_fn=nn.Linear, **kwargs) return model @register_model def convformer_s(**kwargs): model = MetaFormer( depths=[2, 2, 6, 2], dims=[64, 128, 160, 320], mlps=Mlp, mlp_ratio=2, token_mixers=[SepConv, SepConv, SepConv, SepConv], head_fn=nn.Linear, **kwargs) return model @register_model def convformer_b(**kwargs): model = MetaFormer( depths=[2, 2, 6, 2], dims=[128, 256, 320, 512], mlps=Mlp, mlp_ratio=2, token_mixers=[SepConv, SepConv, SepConv, SepConv], head_fn=nn.Linear, **kwargs) return model @register_model def convformer_l(**kwargs): model = MetaFormer( depths=[2, 2, 6, 2], dims=[256, 384, 512, 768], mlps=Mlp, mlp_ratio=2, token_mixers=[SepConv, SepConv, SepConv, SepConv], head_fn=nn.Linear, **kwargs) return model