from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F

from timm.models.layers import trunc_normal_, DropPath
from timm.models.registry import register_model
from timm.models.layers.helpers import to_2tuple

class LayerNormGeneral(nn.Module):
    r""" General LayerNorm for different situations.

    Args:
        affine_shape (int, list or tuple): The shape of affine weight and bias.
            Usually affine_shape=C, but in some implementations, like torch.nn.LayerNorm,
            the affine_shape is the same as normalized_dim by default.
            To adapt to different situations, we offer this argument here.
        normalized_dim (tuple or list): Which dims to compute mean and variance.
        scale (bool): Flag indicating whether to use scale or not.
        bias (bool): Flag indicating whether to use bias or not.

    We give several examples to show how to specify the arguments.

    LayerNorm (https://arxiv.org/abs/1607.06450):
        For input shape of (B, *, C) like (B, N, C) or (B, H, W, C),
            affine_shape=C, normalized_dim=(-1, ), scale=True, bias=True;
        For input shape of (B, C, H, W),
            affine_shape=(C, 1, 1), normalized_dim=(1, ), scale=True, bias=True.

    Modified LayerNorm (https://arxiv.org/abs/2111.11418)
        that is identical to partial(torch.nn.GroupNorm, num_groups=1):
        For input shape of (B, N, C),
            affine_shape=C, normalized_dim=(1, 2), scale=True, bias=True;
        For input shape of (B, H, W, C),
            affine_shape=C, normalized_dim=(1, 2, 3), scale=True, bias=True;
        For input shape of (B, C, H, W),
            affine_shape=(C, 1, 1), normalized_dim=(1, 2, 3), scale=True, bias=True.

    For the several MetaFormer baselines,
        IdentityFormer, RandFormer and PoolFormerV2 utilize Modified LayerNorm without bias (bias=False);
        ConvFormer and CAFormer utilize LayerNorm without bias (bias=False).
    """

    def __init__(self, affine_shape=None, normalized_dim=(-1, ), scale=True,
                 bias=False, eps=1e-6):
        super().__init__()
        self.normalized_dim = normalized_dim
        self.use_scale = scale
        self.use_bias = bias
        self.weight = nn.Parameter(torch.ones(affine_shape)) if scale else None
        self.bias = nn.Parameter(torch.zeros(affine_shape)) if bias else None
        self.eps = eps

    def forward(self, x):
        c = x - x.mean(self.normalized_dim, keepdim=True)
        s = c.pow(2).mean(self.normalized_dim, keepdim=True)
        x = c / torch.sqrt(s + self.eps)
        if self.use_scale:
            x = x * self.weight
        if self.use_bias:
            x = x + self.bias
        return x
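
# A minimal usage sketch (illustrative only, with assumed dummy shapes): the
# classic LayerNorm configuration versus the Modified LayerNorm configuration
# described in the docstring above, both on a (B, N, C) input.
def _demo_layernorm_general():
    x = torch.randn(2, 196, 64)                                       # (B, N, C)
    ln = LayerNormGeneral(affine_shape=64, normalized_dim=(-1,))      # classic LayerNorm
    mln = LayerNormGeneral(affine_shape=64, normalized_dim=(1, 2))    # Modified LayerNorm
    print(ln(x).shape, mln(x).shape)  # both preserve the input shape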

def stem(in_chs, out_chs, act_layer=nn.GELU):
    """Stem of two stride-2 convolutions (4x spatial reduction in total);
    InstanceNorm2d is used here in place of the usual BatchNorm2d."""
    return nn.Sequential(
        nn.Conv2d(in_chs, out_chs // 2, kernel_size=3, stride=2, padding=1),
        nn.InstanceNorm2d(out_chs // 2),  # was: nn.BatchNorm2d(out_chs // 2)
        act_layer(),
        nn.Conv2d(out_chs // 2, out_chs, kernel_size=3, stride=2, padding=1),
        nn.InstanceNorm2d(out_chs),  # was: nn.BatchNorm2d(out_chs)
        act_layer(),
    )
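
# Sketch with an assumed 224x224 input: the stem maps (1, 3, 224, 224) to
# (1, 64, 56, 56), i.e. stride 4, matching the first entry of `dims`.
def _demo_stem():
    s = stem(in_chs=3, out_chs=64)
    print(s(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 56, 56])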

class Downsampling(nn.Module):
    """
    Downsampling implemented by a layer of convolution.
    """
    def __init__(self, in_channels, out_channels,
                 kernel_size=3, stride=2, padding=1,
                 pre_norm=LayerNormGeneral, post_norm=None, pre_permute=True):
        super().__init__()
        self.pre_norm = pre_norm(in_channels) if pre_norm else nn.Identity()
        self.pre_permute = pre_permute
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding)
        self.post_norm = post_norm(out_channels) if post_norm else nn.Identity()

    def forward(self, x):
        x = self.pre_norm(x)
        if self.pre_permute:
            # if taking [B, H, W, C] as input, permute it to [B, C, H, W]
            x = x.permute(0, 3, 1, 2).contiguous()
        x = self.conv(x)
        x = x.permute(0, 2, 3, 1).contiguous()  # [B, C, H, W] -> [B, H, W, C]
        x = self.post_norm(x)
        return x
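
# Sketch with assumed shapes: with the default pre_permute=True, the layer
# consumes a channels-last [B, H, W, C] map and halves the spatial resolution.
def _demo_downsampling():
    down = Downsampling(64, 128)
    print(down(torch.randn(1, 56, 56, 64)).shape)  # torch.Size([1, 28, 28, 128])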

class Scale(nn.Module):
    """
    Scale vector by element multiplications.
    """
    def __init__(self, dim, init_value=1.0, trainable=True):
        super().__init__()
        self.scale = nn.Parameter(
            init_value * torch.ones(dim), requires_grad=trainable)

    def forward(self, x):
        return x * self.scale

class LayerNormWithoutBias(nn.Module):
    """
    Equal to partial(LayerNormGeneral, bias=False) but faster,
    because it directly utilizes the optimized F.layer_norm.
    """
    def __init__(self, normalized_shape, eps=1e-5, **kwargs):
        super().__init__()
        self.eps = eps
        self.bias = None
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape,)
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.normalized_shape = normalized_shape

    def forward(self, x):
        return F.layer_norm(x, self.normalized_shape, weight=self.weight,
                            bias=self.bias, eps=self.eps)
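
# Sketch (assumed shapes): a numerical check that this fast path matches
# LayerNormGeneral with bias=False when both use the same eps (note the
# differing defaults: 1e-5 here versus 1e-6 in LayerNormGeneral).
def _demo_layernorm_equivalence():
    x = torch.randn(2, 14, 14, 64)
    fast = LayerNormWithoutBias(64, eps=1e-6)
    ref = LayerNormGeneral(affine_shape=64, normalized_dim=(-1,), eps=1e-6)
    print(torch.allclose(fast(x), ref(x), atol=1e-5))  # True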

class SepConv(nn.Module):
    r"""
    Inverted separable convolution from MobileNetV2: https://arxiv.org/abs/1801.04381.
    """
    def __init__(self, dim, expansion_ratio=2,
                 act1_layer=nn.GELU, act2_layer=nn.Identity,
                 bias=False, kernel_size=3, padding=1,
                 **kwargs):
        super().__init__()
        med_channels = int(expansion_ratio * dim)
        self.pwconv1 = nn.Linear(dim, med_channels, bias=bias)
        self.act1 = act1_layer()
        self.dwconv = nn.Conv2d(
            med_channels, med_channels, kernel_size=kernel_size,
            padding=padding, groups=med_channels, bias=bias)  # depthwise conv
        self.act2 = act2_layer()
        self.pwconv2 = nn.Linear(med_channels, dim, bias=bias)

    def forward(self, x):
        x = self.pwconv1(x)
        x = self.act1(x)
        x = x.permute(0, 3, 1, 2)  # [B, H, W, C] -> [B, C, H, W]
        x = self.dwconv(x)
        x = x.permute(0, 2, 3, 1)  # [B, C, H, W] -> [B, H, W, C]
        x = self.act2(x)
        x = self.pwconv2(x)
        return x
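
# Sketch with an assumed channels-last input: SepConv expands the channels,
# mixes tokens spatially with a depthwise conv, then projects back,
# preserving the [B, H, W, C] shape.
def _demo_sepconv():
    mixer = SepConv(dim=64)
    print(mixer(torch.randn(1, 28, 28, 64)).shape)  # torch.Size([1, 28, 28, 64])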

class Mlp(nn.Module):
    """ MLP as used in MetaFormer models, e.g. Transformer, MLP-Mixer, PoolFormer,
    the MetaFormer baselines and related networks.
    Mostly copied from timm.
    """
    def __init__(self, dim, mlp_ratio=4, out_features=None, act_layer=nn.GELU,
                 drop=0., bias=False, **kwargs):
        super().__init__()
        in_features = dim
        out_features = out_features or in_features
        hidden_features = int(mlp_ratio * in_features)
        drop_probs = to_2tuple(drop)

        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop_probs[0])
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop_probs[1])

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x

class MetaFormerBlock(nn.Module):
    """
    Implementation of one MetaFormer block.
    """
    def __init__(self, dim,
                 token_mixer=nn.Identity, mlp=Mlp, mlp_ratio=4,
                 norm_layer=nn.LayerNorm, drop=0., drop_path=0.,
                 layer_scale_init_value=None, res_scale_init_value=None
                 ):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.token_mixer = token_mixer(dim, drop=drop)
        self.drop_path1 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale1 = Scale(dim=dim, init_value=layer_scale_init_value) \
            if layer_scale_init_value else nn.Identity()
        self.res_scale1 = Scale(dim=dim, init_value=res_scale_init_value) \
            if res_scale_init_value else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = mlp(dim=dim, mlp_ratio=mlp_ratio, drop=drop)
        self.drop_path2 = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.layer_scale2 = Scale(dim=dim, init_value=layer_scale_init_value) \
            if layer_scale_init_value else nn.Identity()
        self.res_scale2 = Scale(dim=dim, init_value=res_scale_init_value) \
            if res_scale_init_value else nn.Identity()

    def forward(self, x):
        # apply res scale on the shortcut and layer scale on the residual branch;
        # both default to nn.Identity when their init values are None
        x = self.res_scale1(x) + \
            self.layer_scale1(self.drop_path1(self.token_mixer(self.norm1(x))))
        x = self.res_scale2(x) + \
            self.layer_scale2(self.drop_path2(self.mlp(self.norm2(x))))
        return x
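
# Sketch (assumed dummy shape): one block with a SepConv token mixer, the
# combination used by the convformer_* builders at the bottom of this file.
def _demo_metaformer_block():
    block = MetaFormerBlock(dim=64, token_mixer=SepConv,
                            norm_layer=LayerNormWithoutBias)
    print(block(torch.randn(1, 28, 28, 64)).shape)  # torch.Size([1, 28, 28, 64])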

class MetaFormer(nn.Module):
    r""" MetaFormer
    A PyTorch impl of : `MetaFormer Baselines for Vision` -
      https://arxiv.org/abs/2210.13452

    Args:
        in_chans (int): Number of input image channels. Default: 3.
        num_classes (int): Number of classes for classification head. Default: 1000.
        depths (list or tuple): Number of blocks at each stage. Default: [2, 2, 6, 2].
        dims (list or tuple): Feature dimension at each stage. Default: [64, 128, 320, 512].
        downsample_layers (list or tuple): Downsampling layers before each stage.
        token_mixers (list, tuple or token_fcn): Token mixer for each stage. Default: nn.Identity.
        mlps (list, tuple or mlp_fcn): Mlp for each stage. Default: Mlp.
        norm_layers (list, tuple or norm_fcn): Norm layers for each stage. Default: partial(LayerNormWithoutBias, eps=1e-6).
        drop_path_rate (float): Stochastic depth rate. Default: 0.
        layer_scale_init_values (list, tuple, float or None): Init value for Layer Scale. Default: None.
            None means the layer scale is not used. From: https://arxiv.org/abs/2103.17239.
        res_scale_init_values (list, tuple, float or None): Init value for Res Scale. Default: [None, None, 1.0, 1.0].
            None means the res scale is not used. From: https://arxiv.org/abs/2110.09456.
        head_fn: classification head. Default: nn.Linear.

    Note: forward() returns the per-stage feature maps (channels-last), so the model
    acts as a backbone here; the classification head is built but not applied.
    """

    def __init__(self, in_chans=3, num_classes=1000,
                 depths=[2, 2, 6, 2],
                 dims=[64, 128, 320, 512],
                 downsample_layers=[stem] + [Downsampling] * 3,
                 token_mixers=nn.Identity,
                 mlps=Mlp, mlp_ratio=4,
                 norm_layers=partial(LayerNormWithoutBias, eps=1e-6),
                 drop_path_rate=0.,
                 layer_scale_init_values=None,
                 res_scale_init_values=[None, None, 1.0, 1.0],
                 head_fn=nn.Linear,
                 **kwargs,
                 ):
        super().__init__()
        self.num_classes = num_classes

        if not isinstance(depths, (list, tuple)):
            depths = [depths]  # it means the model has only one stage
        if not isinstance(dims, (list, tuple)):
            dims = [dims]
        self.dims = dims
        self.depths = depths
        num_stage = len(depths)
        self.num_stage = num_stage

        down_dims = [in_chans] + dims
        self.downsample_layers = nn.ModuleList(
            [downsample_layers[i](down_dims[i], down_dims[i + 1]) for i in range(num_stage)])

        if not isinstance(token_mixers, (list, tuple)):
            token_mixers = [token_mixers] * num_stage
        self.token_mixers = token_mixers

        if not isinstance(mlps, (list, tuple)):
            mlps = [mlps] * num_stage

        if not isinstance(norm_layers, (list, tuple)):
            norm_layers = [norm_layers] * num_stage

        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]

        if not isinstance(layer_scale_init_values, (list, tuple)):
            layer_scale_init_values = [layer_scale_init_values] * num_stage
        if not isinstance(res_scale_init_values, (list, tuple)):
            res_scale_init_values = [res_scale_init_values] * num_stage

        self.stages = nn.ModuleList()  # each stage consists of multiple metaformer blocks
        cur = 0
        for i in range(num_stage):
            stage = nn.ModuleList(
                [MetaFormerBlock(dim=dims[i], token_mixer=token_mixers[i],
                                 mlp=mlps[i], mlp_ratio=mlp_ratio, norm_layer=norm_layers[i],
                                 drop_path=dp_rates[cur + j],
                                 layer_scale_init_value=layer_scale_init_values[i],
                                 res_scale_init_value=res_scale_init_values[i],
                                 ) for j in range(depths[i])]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.head = head_fn(dims[-1], num_classes)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        outs = []
        for i in range(self.num_stage):
            x = self.downsample_layers[i](x)
            if i == 0:
                x = x.permute(0, 2, 3, 1).contiguous()  # stem output: [B, C, H, W] -> [B, H, W, C]
            for j in range(self.depths[i]):
                x = self.stages[i][j](x)
            outs.append(x)  # [B, H, W, C]
        return outs
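
# Sketch with an assumed 224x224 input: used as a backbone, the model returns
# one channels-last feature map per stage, at strides 4, 8, 16 and 32.
def _demo_metaformer_features():
    model = MetaFormer(token_mixers=SepConv)
    feats = model(torch.randn(1, 3, 224, 224))
    print([tuple(f.shape) for f in feats])
    # [(1, 56, 56, 64), (1, 28, 28, 128), (1, 14, 14, 320), (1, 7, 7, 512)]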

def convformer(variant='tiny'):
    if variant == 'tiny':
        model = convformer_t()
    elif variant == 'small':
        model = convformer_s()
    elif variant == 'base':
        model = convformer_b()
    elif variant == 'large':
        model = convformer_l()
    else:
        raise NotImplementedError(f"Unknown ConvFormer variant: {variant!r}")
    return model

def convformer_t(**kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[32, 64, 128, 160],
        mlps=Mlp, mlp_ratio=2,
        token_mixers=[SepConv, SepConv, SepConv, SepConv],
        head_fn=nn.Linear,
        **kwargs)
    return model


def convformer_s(**kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[64, 128, 160, 320],
        mlps=Mlp, mlp_ratio=2,
        token_mixers=[SepConv, SepConv, SepConv, SepConv],
        head_fn=nn.Linear,
        **kwargs)
    return model


def convformer_b(**kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[128, 256, 320, 512],
        mlps=Mlp, mlp_ratio=2,
        token_mixers=[SepConv, SepConv, SepConv, SepConv],
        head_fn=nn.Linear,
        **kwargs)
    return model


def convformer_l(**kwargs):
    model = MetaFormer(
        depths=[2, 2, 6, 2],
        dims=[256, 384, 512, 768],
        mlps=Mlp, mlp_ratio=2,
        token_mixers=[SepConv, SepConv, SepConv, SepConv],
        head_fn=nn.Linear,
        **kwargs)
    return model
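
# Sketch with an assumed input size: building a variant by name and running a
# forward pass; each returned feature map is channels-last [B, H, W, C].
def _demo_convformer():
    model = convformer('small')
    feats = model(torch.randn(1, 3, 256, 256))
    print([tuple(f.shape) for f in feats])
    # [(1, 64, 64, 64), (1, 32, 32, 128), (1, 16, 16, 160), (1, 8, 8, 320)]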