import math

import torch
import torch.nn as nn
from timm.layers import DropPath, trunc_normal_
from timm.models.convnext import ConvNeXtBlock
from timm.models.mlp_mixer import MixerBlock
from timm.models.swin_transformer import SwinTransformerBlock, window_partition, window_reverse
from timm.models.vision_transformer import Block as ViTBlock

from .layers import (HorBlock, ChannelAggregationFFN, MultiOrderGatedAggregation,
                     PoolFormerBlock, CBlock, SABlock, MixMlp, VANBlock)


class BasicConv2d(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=0,
                 dilation=1,
                 upsampling=False,
                 act_norm=False,
                 act_inplace=True):
        super(BasicConv2d, self).__init__()
        self.act_norm = act_norm
        if upsampling is True:
            # sub-pixel upsampling: conv to 4x channels, then PixelShuffle(2)
            self.conv = nn.Sequential(*[
                nn.Conv2d(in_channels, out_channels*4, kernel_size=kernel_size,
                          stride=1, padding=padding, dilation=dilation),
                nn.PixelShuffle(2)
            ])
        else:
            self.conv = nn.Conv2d(
                in_channels, out_channels, kernel_size=kernel_size,
                stride=stride, padding=padding, dilation=dilation)
        self.norm = nn.GroupNorm(2, out_channels)
        self.act = nn.SiLU(inplace=act_inplace)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Conv2d):
            trunc_normal_(m.weight, std=.02)
            nn.init.constant_(m.bias, 0)

    def forward(self, x):
        y = self.conv(x)
        if self.act_norm:
            y = self.act(self.norm(y))
        return y


class ConvSC(nn.Module):

    def __init__(self,
                 C_in,
                 C_out,
                 kernel_size=3,
                 downsampling=False,
                 upsampling=False,
                 act_norm=True,
                 act_inplace=True):
        super(ConvSC, self).__init__()

        stride = 2 if downsampling is True else 1
        padding = (kernel_size - stride + 1) // 2

        self.conv = BasicConv2d(C_in, C_out, kernel_size=kernel_size, stride=stride,
                                upsampling=upsampling, padding=padding,
                                act_norm=act_norm, act_inplace=act_inplace)

    def forward(self, x):
        y = self.conv(x)
        return y


class GroupConv2d(nn.Module):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size=3,
                 stride=1,
                 padding=0,
                 groups=1,
                 act_norm=False,
                 act_inplace=True):
        super(GroupConv2d, self).__init__()
        self.act_norm = act_norm
        if in_channels % groups != 0:
            # fall back to a plain convolution when channels are not divisible by groups
            groups = 1
        self.conv = nn.Conv2d(
            in_channels, out_channels, kernel_size=kernel_size,
            stride=stride, padding=padding, groups=groups)
        self.norm = nn.GroupNorm(groups, out_channels)
        self.activate = nn.LeakyReLU(0.2, inplace=act_inplace)

    def forward(self, x):
        y = self.conv(x)
        if self.act_norm:
            y = self.activate(self.norm(y))
        return y


class gInception_ST(nn.Module):
    """An IncepU block for SimVP"""

    def __init__(self, C_in, C_hid, C_out, incep_ker=[3, 5, 7, 11], groups=8):
        super(gInception_ST, self).__init__()
        self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0)

        layers = []
        for ker in incep_ker:
            layers.append(GroupConv2d(
                C_hid, C_out, kernel_size=ker, stride=1,
                padding=ker//2, groups=groups, act_norm=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        # sum of parallel group-conv branches with different kernel sizes
        y = 0
        for layer in self.layers:
            y += layer(x)
        return y


class AttentionModule(nn.Module):
    """Large Kernel Attention for SimVP"""

    def __init__(self, dim, kernel_size, dilation=3):
        super().__init__()
        d_k = 2 * dilation - 1
        d_p = (d_k - 1) // 2
        dd_k = kernel_size // dilation + ((kernel_size // dilation) % 2 - 1)
        dd_p = (dilation * (dd_k - 1) // 2)

        self.conv0 = nn.Conv2d(dim, dim, d_k, padding=d_p, groups=dim)
        self.conv_spatial = nn.Conv2d(
            dim, dim, dd_k, stride=1, padding=dd_p, groups=dim, dilation=dilation)
        self.conv1 = nn.Conv2d(dim, 2*dim, 1)

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)            # depth-wise conv
        attn = self.conv_spatial(attn)  # depth-wise dilation convolution
        f_g = self.conv1(attn)
        split_dim = f_g.shape[1] // 2
        f_x, g_x = torch.split(f_g, split_dim, dim=1)
        return torch.sigmoid(g_x) * f_x
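

# Worked example for the kernel decomposition above (a sketch; kernel_size=21 is the
# default used by SpatialAttention below): with dilation=3, d_k = 2*3 - 1 = 5 and
# dd_k = 21 // 3 + ((21 // 3) % 2 - 1) = 7, so a 21x21 large-kernel attention is
# approximated by a 5x5 depth-wise conv, a 7x7 depth-wise conv with dilation 3, and a
# 1x1 conv producing the feature/gate pair that forward() splits and gates.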


class SpatialAttention(nn.Module):
    """A Spatial Attention block for SimVP"""

    def __init__(self, d_model, kernel_size=21, attn_shortcut=True):
        super().__init__()

        self.proj_1 = nn.Conv2d(d_model, d_model, 1)  # 1x1 conv
        self.activation = nn.GELU()                   # GELU
        self.spatial_gating_unit = AttentionModule(d_model, kernel_size)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)  # 1x1 conv
        self.attn_shortcut = attn_shortcut

    def forward(self, x):
        if self.attn_shortcut:
            shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        if self.attn_shortcut:
            x = x + shortcut
        return x


class GASubBlock(nn.Module):
    """A GABlock (gSTA) for SimVP"""

    def __init__(self, dim, kernel_size=21, mlp_ratio=4.,
                 drop=0., drop_path=0.1, init_value=1e-2, act_layer=nn.GELU):
        super().__init__()
        self.norm1 = nn.BatchNorm2d(dim)
        self.attn = SpatialAttention(dim, kernel_size)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

        self.norm2 = nn.BatchNorm2d(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = MixMlp(
            in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.layer_scale_1 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'layer_scale_1', 'layer_scale_2'}

    def forward(self, x):
        x = x + self.drop_path(
            self.layer_scale_1.unsqueeze(-1).unsqueeze(-1) * self.attn(self.norm1(x)))
        x = x + self.drop_path(
            self.layer_scale_2.unsqueeze(-1).unsqueeze(-1) * self.mlp(self.norm2(x)))
        return x
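

# Illustrative usage of GASubBlock (a sketch; the shapes and sizes are assumptions,
# not part of the original file). The block keeps the (B, C, H, W) layout and shape:
#   >>> blk = GASubBlock(dim=64, kernel_size=21)
#   >>> blk(torch.randn(2, 64, 16, 16)).shape   # expected: torch.Size([2, 64, 16, 16])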


class ConvMixerSubBlock(nn.Module):
    """A block of ConvMixer."""

    def __init__(self, dim, kernel_size=9, activation=nn.GELU):
        super().__init__()
        # spatial mixing
        self.conv_dw = nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same")
        self.act_1 = activation()
        self.norm_1 = nn.BatchNorm2d(dim)
        # channel mixing
        self.conv_pw = nn.Conv2d(dim, dim, kernel_size=1)
        self.act_2 = activation()
        self.norm_2 = nn.BatchNorm2d(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    @torch.jit.ignore
    def no_weight_decay(self):
        return dict()

    def forward(self, x):
        x = x + self.norm_1(self.act_1(self.conv_dw(x)))
        x = self.norm_2(self.act_2(self.conv_pw(x)))
        return x


class ConvNeXtSubBlock(ConvNeXtBlock):
    """A block of ConvNeXt."""

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.1):
        super().__init__(dim, mlp_ratio=mlp_ratio, drop_path=drop_path,
                         ls_init_value=1e-6, conv_mlp=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'gamma'}

    def forward(self, x):
        x = x + self.drop_path(
            self.gamma.reshape(1, -1, 1, 1) * self.mlp(self.norm(self.conv_dw(x))))
        return x


class HorNetSubBlock(HorBlock):
    """A block of HorNet."""

    def __init__(self, dim, mlp_ratio=4., drop_path=0.1, init_value=1e-6):
        super().__init__(dim, mlp_ratio=mlp_ratio, drop_path=drop_path, init_value=init_value)
        self.apply(self._init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'gamma1', 'gamma2'}

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()


class MLPMixerSubBlock(MixerBlock):
    """A block of MLP-Mixer."""

    def __init__(self, dim, input_resolution=None, mlp_ratio=4., drop=0., drop_path=0.1):
        seq_len = input_resolution[0] * input_resolution[1]
        super().__init__(dim, seq_len=seq_len, mlp_ratio=(0.5, mlp_ratio),
                         drop_path=drop_path, drop=drop)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return dict()

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = x + self.drop_path(self.mlp_tokens(self.norm1(x).transpose(1, 2)).transpose(1, 2))
        x = x + self.drop_path(self.mlp_channels(self.norm2(x)))
        return x.reshape(B, H, W, C).permute(0, 3, 1, 2)
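

# Note (a sketch; the sizes are assumptions, not part of the original file): token-mixing
# blocks such as MLPMixerSubBlock need `input_resolution` up front because the token MLP
# has a fixed sequence length H * W; forward() flattens (B, C, H, W) into (B, H*W, C)
# tokens and restores the image layout afterwards.
#   >>> blk = MLPMixerSubBlock(dim=64, input_resolution=(16, 16))   # seq_len = 256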


class MogaSubBlock(nn.Module):
    """A block of MogaNet."""

    def __init__(self, embed_dims, mlp_ratio=4., drop_rate=0., drop_path_rate=0.,
                 init_value=1e-5, attn_dw_dilation=[1, 2, 3], attn_channel_split=[1, 3, 4]):
        super(MogaSubBlock, self).__init__()
        self.out_channels = embed_dims
        # spatial attention
        self.norm1 = nn.BatchNorm2d(embed_dims)
        self.attn = MultiOrderGatedAggregation(
            embed_dims, attn_dw_dilation=attn_dw_dilation, attn_channel_split=attn_channel_split)
        self.drop_path = DropPath(
            drop_path_rate) if drop_path_rate > 0. else nn.Identity()
        # channel MLP
        self.norm2 = nn.BatchNorm2d(embed_dims)
        mlp_hidden_dims = int(embed_dims * mlp_ratio)
        self.mlp = ChannelAggregationFFN(
            embed_dims=embed_dims, mlp_hidden_dims=mlp_hidden_dims, ffn_drop=drop_rate)
        # init layer scale
        self.layer_scale_1 = nn.Parameter(
            init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)
        self.layer_scale_2 = nn.Parameter(
            init_value * torch.ones((1, embed_dims, 1, 1)), requires_grad=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'layer_scale_1', 'layer_scale_2', 'sigma'}

    def forward(self, x):
        x = x + self.drop_path(self.layer_scale_1 * self.attn(self.norm1(x)))
        x = x + self.drop_path(self.layer_scale_2 * self.mlp(self.norm2(x)))
        return x


class PoolFormerSubBlock(PoolFormerBlock):
    """A block of PoolFormer."""

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.1):
        super().__init__(dim, pool_size=3, mlp_ratio=mlp_ratio,
                         drop_path=drop_path, drop=drop, init_value=1e-5)
        self.apply(self._init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'layer_scale_1', 'layer_scale_2'}

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)


class SwinSubBlock(SwinTransformerBlock):
    """A block of Swin Transformer."""

    def __init__(self, dim, input_resolution=None, layer_i=0, mlp_ratio=4., drop=0., drop_path=0.1):
        window_size = 7 if input_resolution[0] % 7 == 0 else max(4, input_resolution[0] // 16)
        window_size = min(8, window_size)
        # alternate between regular and shifted windows across consecutive layers
        shift_size = 0 if (layer_i % 2 == 0) else window_size // 2
        super().__init__(dim, input_resolution, num_heads=8, window_size=window_size,
                         shift_size=shift_size, mlp_ratio=mlp_ratio, drop_path=drop_path,
                         attn_drop=drop, proj_drop=drop, qkv_bias=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {}

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm1(x)
        x = x.view(B, H, W, C)
        x = super().forward(x)
        return x.reshape(B, H, W, C).permute(0, 3, 1, 2)


def UniformerSubBlock(embed_dims, mlp_ratio=4., drop=0., drop_path=0.,
                      init_value=1e-6, block_type='Conv'):
    """Build a block of Uniformer."""
    assert block_type in ['Conv', 'MHSA']
    if block_type == 'Conv':
        return CBlock(dim=embed_dims, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
    else:
        return SABlock(dim=embed_dims, num_heads=8, mlp_ratio=mlp_ratio, qkv_bias=True,
                       drop=drop, drop_path=drop_path, init_value=init_value)
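

# Illustrative usage of the factory above (a sketch; the dims are assumptions, not part
# of the original file). Uniformer stages typically use convolutional mixing in shallow
# layers and self-attention in deep layers:
#   >>> shallow = UniformerSubBlock(64, block_type='Conv')   # returns a CBlock
#   >>> deep = UniformerSubBlock(64, block_type='MHSA')      # returns an SABlock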


class VANSubBlock(VANBlock):
    """A block of VAN."""

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0., init_value=1e-2, act_layer=nn.GELU):
        super().__init__(dim=dim, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path,
                         init_value=init_value, act_layer=act_layer)
        self.apply(self._init_weights)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'layer_scale_1', 'layer_scale_2'}

    def _init_weights(self, m):
        if isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()


class ViTSubBlock(ViTBlock):
    """A block of Vision Transformer."""

    def __init__(self, dim, mlp_ratio=4., drop=0., drop_path=0.1):
        super().__init__(dim=dim, num_heads=8, mlp_ratio=mlp_ratio, qkv_bias=True,
                         attn_drop=drop, proj_drop=0, drop_path=drop_path,
                         act_layer=nn.GELU, norm_layer=nn.LayerNorm)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {}

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x.reshape(B, H, W, C).permute(0, 3, 1, 2)


class TemporalAttention(nn.Module):
    """A Temporal Attention block for the Temporal Attention Unit"""

    def __init__(self, d_model, kernel_size=21, attn_shortcut=True):
        super().__init__()

        self.proj_1 = nn.Conv2d(d_model, d_model, 1)  # 1x1 conv
        self.activation = nn.GELU()                   # GELU
        self.spatial_gating_unit = TemporalAttentionModule(d_model, kernel_size)
        self.proj_2 = nn.Conv2d(d_model, d_model, 1)  # 1x1 conv
        self.attn_shortcut = attn_shortcut

    def forward(self, x):
        if self.attn_shortcut:
            shortcut = x.clone()
        x = self.proj_1(x)
        x = self.activation(x)
        x = self.spatial_gating_unit(x)
        x = self.proj_2(x)
        if self.attn_shortcut:
            x = x + shortcut
        return x


class TemporalAttentionModule(nn.Module):
    """Large Kernel Attention with a squeeze-and-excitation branch for TAU"""

    def __init__(self, dim, kernel_size, dilation=3, reduction=16):
        super().__init__()
        d_k = 2 * dilation - 1
        d_p = (d_k - 1) // 2
        dd_k = kernel_size // dilation + ((kernel_size // dilation) % 2 - 1)
        dd_p = (dilation * (dd_k - 1) // 2)

        self.conv0 = nn.Conv2d(dim, dim, d_k, padding=d_p, groups=dim)
        self.conv_spatial = nn.Conv2d(
            dim, dim, dd_k, stride=1, padding=dd_p, groups=dim, dilation=dilation)
        self.conv1 = nn.Conv2d(dim, dim, 1)

        self.reduction = max(dim // reduction, 4)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Sequential(
            nn.Linear(dim, dim // self.reduction, bias=False),  # reduction
            nn.ReLU(True),
            nn.Linear(dim // self.reduction, dim, bias=False),  # expansion
            nn.Sigmoid()
        )

    def forward(self, x):
        u = x.clone()
        attn = self.conv0(x)            # depth-wise conv
        attn = self.conv_spatial(attn)  # depth-wise dilation convolution
        f_x = self.conv1(attn)          # 1x1 conv
        # squeeze-and-excitation (SE) channel attention
        b, c, _, _ = x.size()
        se_atten = self.avg_pool(x).view(b, c)
        se_atten = self.fc(se_atten).view(b, c, 1, 1)
        return se_atten * f_x * u
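

# Worked example for the squeeze-and-excitation branch above (a sketch; dim=64 is an
# assumed value): with reduction=16, self.reduction = max(64 // 16, 4) = 4, so the
# bottleneck MLP maps 64 -> 16 -> 64 channels, and its sigmoid output re-weights the
# large-kernel attention map f_x and the input u per channel.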


class TAUSubBlock(GASubBlock):
    """A TAUBlock (tau) for the Temporal Attention Unit"""

    def __init__(self, dim, kernel_size=21, mlp_ratio=4.,
                 drop=0., drop_path=0.1, init_value=1e-2, act_layer=nn.GELU):
        super().__init__(dim=dim, kernel_size=kernel_size, mlp_ratio=mlp_ratio,
                         drop=drop, drop_path=drop_path, init_value=init_value,
                         act_layer=act_layer)
        # replace the spatial attention of GASubBlock with SE-augmented temporal attention
        self.attn = TemporalAttention(dim, kernel_size)
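

# Illustrative usage of the TAU block (a sketch; shapes and sizes are assumptions, not
# part of the original file). TAUSubBlock reuses the GASubBlock structure but swaps in
# the SE-augmented temporal attention, keeping the (B, C, H, W) layout and shape:
#   >>> blk = TAUSubBlock(dim=64, kernel_size=21)
#   >>> blk(torch.randn(2, 64, 16, 16)).shape   # expected: torch.Size([2, 64, 16, 16])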