"""SimVP-style spatio-temporal prediction models.

:class:`SimVP` and :class:`larres` share one layout: a per-frame
convolutional :class:`Encoder`, a hidden translator that sees every frame of
a clip at once (frames are stacked along the channel axis), and a per-frame
:class:`Decoder` that also receives the first encoder layer's output as a
skip connection.
"""
import torch
from torch import nn

from modules import ConvSC, Inception
from utilpack import (
    ConvNeXtSubBlock,
    ConvMixerSubBlock,
    GASubBlock,
    gInception_ST,
    HorNetSubBlock,
    MLPMixerSubBlock,
    MogaSubBlock,
    PoolFormerSubBlock,
    SwinSubBlock,
    UniformerSubBlock,
    VANSubBlock,
    ViTSubBlock,
    TAUSubBlock,
)

# (T, C, H, W) clip shape both top-level models were originally hard-coded for.
_DEFAULT_IN_SHAPE = (36, 1, 72, 72)
# Inception kernel sizes used when the caller does not pass ``incep_ker``.
_DEFAULT_INCEP_KER = (3, 5, 7, 11)


def stride_generator(N, reverse=False):
    """Return the per-layer strides for an ``N``-layer encoder/decoder.

    Strides alternate ``1, 2, 1, 2, ...`` starting with 1.

    Args:
        N: number of layers. Values <= 0 yield an empty list.
        reverse: if True, return the sequence reversed (the decoder uses this
            so its layers mirror the encoder's).

    Returns:
        A list of ``N`` ints, each 1 or 2.
    """
    strides = [1 if i % 2 == 0 else 2 for i in range(N)]
    return strides[::-1] if reverse else strides


class Encoder(nn.Module):
    """Per-frame convolutional encoder made of ``N_S`` ConvSC layers.

    Args:
        C_in: channels of each input frame.
        C_hid: channels produced by every layer.
        N_S: number of layers (>= 1); strides follow :func:`stride_generator`.
    """

    def __init__(self, C_in, C_hid, N_S):
        super().__init__()
        strides = stride_generator(N_S)
        self.enc = nn.Sequential(
            ConvSC(C_in, C_hid, stride=strides[0]),
            *[ConvSC(C_hid, C_hid, stride=s) for s in strides[1:]],
        )

    def forward(self, x):
        """Encode a batch of frames.

        Args:
            x: frames; the models below pass them as ``(B*T, C_in, H, W)``.

        Returns:
            ``(latent, enc1)``: the last layer's output and the first layer's
            output (the latter is the decoder's skip connection).
        """
        layers = iter(self.enc)
        enc1 = next(layers)(x)
        latent = enc1
        for layer in layers:
            latent = layer(latent)
        return latent, enc1


class Decoder(nn.Module):
    """Per-frame transposed-convolution decoder mirroring :class:`Encoder`.

    Args:
        C_hid: channel width of the latent map and of the skip tensor.
        C_out: channels of each reconstructed frame.
        N_S: number of layers (>= 1); strides are the encoder's, reversed.
    """

    def __init__(self, C_hid, C_out, N_S):
        super().__init__()
        strides = stride_generator(N_S, reverse=True)
        self.dec = nn.Sequential(
            *[ConvSC(C_hid, C_hid, stride=s, transpose=True) for s in strides[:-1]],
            # The last layer consumes [hid, enc1] concatenated on channels.
            ConvSC(2 * C_hid, C_hid, stride=strides[-1], transpose=True),
        )
        self.readout = nn.Conv2d(C_hid, C_out, 1)

    def forward(self, hid, enc1=None):
        """Decode ``hid`` using the encoder skip tensor ``enc1``.

        ``enc1`` keeps its ``None`` default for signature compatibility, but
        it is required: the last layer expects ``2 * C_hid`` input channels.

        Raises:
            TypeError: if ``enc1`` is None.
        """
        if enc1 is None:
            raise TypeError("Decoder.forward requires the encoder skip tensor 'enc1'")
        *body, last = self.dec
        for layer in body:
            hid = layer(hid)
        Y = last(torch.cat([hid, enc1], dim=1))
        return self.readout(Y)


class Mid_Xnet(nn.Module):
    """Inception-based U-shaped hidden translator.

    Args:
        channel_in: channels of the flattened clip, i.e. ``T * C`` of the
            ``(B, T, C, H, W)`` tensor passed to :meth:`forward`.
        channel_hid: hidden width of every Inception block.
        N_T: number of encoder blocks (and of decoder blocks); must be >= 2.
        incep_ker: kernel sizes forwarded to every ``Inception`` block;
            ``None`` means ``[3, 5, 7, 11]``.
        groups: group count forwarded to every ``Inception`` block.

    Raises:
        ValueError: if ``N_T < 2``.
    """

    def __init__(self, channel_in, channel_hid, N_T, incep_ker=None, groups=8):
        super().__init__()
        if N_T < 2:
            # With fewer than two blocks the skip/decoder wiring in forward()
            # cannot map back to channel_in channels.
            raise ValueError(f"Mid_Xnet requires N_T >= 2, got {N_T}")
        if incep_ker is None:
            incep_ker = list(_DEFAULT_INCEP_KER)
        self.N_T = N_T

        def inception(c_in, c_out):
            return Inception(c_in, channel_hid // 2, c_out,
                             incep_ker=incep_ker, groups=groups)

        # Construction order matches the original so seeded init is unchanged.
        # Encoder: channel_in -> channel_hid, then N_T-1 hid -> hid blocks.
        enc_layers = [inception(channel_in, channel_hid)]
        enc_layers += [inception(channel_hid, channel_hid) for _ in range(N_T - 1)]
        # Decoder: the first block sees z alone; the rest see cat([z, skip]).
        dec_layers = [inception(channel_hid, channel_hid)]
        dec_layers += [inception(2 * channel_hid, channel_hid) for _ in range(N_T - 2)]
        dec_layers.append(inception(2 * channel_hid, channel_in))
        self.enc = nn.Sequential(*enc_layers)
        self.dec = nn.Sequential(*dec_layers)

    def forward(self, x):
        """Translate a ``(B, T, C, H, W)`` tensor; returns the same shape."""
        B, T, C, H, W = x.shape
        z = x.reshape(B, T * C, H, W)

        # Encoder: keep every output except the last as a skip connection.
        skips = []
        for i, layer in enumerate(self.enc):
            z = layer(z)
            if i < self.N_T - 1:
                skips.append(z)

        # Decoder: consume skips deepest-first (U-Net style).
        dec_layers = iter(self.dec)
        z = next(dec_layers)(z)
        for layer, skip in zip(dec_layers, reversed(skips)):
            z = layer(torch.cat([z, skip], dim=1))
        return z.reshape(B, T, C, H, W)


class MetaBlock(nn.Module):
    """One MetaFormer-style block of the SimVP hidden translator.

    Wraps a sub-block chosen by ``model_type`` and, when
    ``in_channels != out_channels``, a 1x1 convolution mapping the sub-block
    output to ``out_channels``.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        input_resolution: spatial size forwarded to the ``'mlp'``/``'mlpmixer'``
            and ``'swin'`` sub-blocks; ignored by the others.
        model_type: case-insensitive sub-block name; ``None`` means ``'gsta'``.
        mlp_ratio, drop, drop_path: forwarded to the sub-block where its
            constructor call below accepts them.
        layer_i: index of this block in the translator; forwarded to the Swin
            sub-block and used to pick the Uniformer block type.

    Raises:
        ValueError: if ``model_type`` is not a supported name.
    """

    def __init__(self, in_channels, out_channels, input_resolution=None,
                 model_type=None, mlp_ratio=8., drop=0.0, drop_path=0.0,
                 layer_i=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        model_type = model_type.lower() if model_type is not None else 'gsta'

        if model_type == 'gsta':
            self.block = GASubBlock(
                in_channels, kernel_size=21, mlp_ratio=mlp_ratio,
                drop=drop, drop_path=drop_path, act_layer=nn.GELU)
        elif model_type == 'convmixer':
            self.block = ConvMixerSubBlock(in_channels, kernel_size=11, activation=nn.GELU)
        elif model_type == 'convnext':
            self.block = ConvNeXtSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
        elif model_type == 'hornet':
            self.block = HorNetSubBlock(in_channels, mlp_ratio=mlp_ratio, drop_path=drop_path)
        elif model_type in ('mlp', 'mlpmixer'):
            self.block = MLPMixerSubBlock(
                in_channels, input_resolution, mlp_ratio=mlp_ratio,
                drop=drop, drop_path=drop_path)
        elif model_type in ('moga', 'moganet'):
            self.block = MogaSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop_rate=drop, drop_path_rate=drop_path)
        elif model_type == 'poolformer':
            self.block = PoolFormerSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
        elif model_type == 'swin':
            self.block = SwinSubBlock(
                in_channels, input_resolution, layer_i=layer_i, mlp_ratio=mlp_ratio,
                drop=drop, drop_path=drop_path)
        elif model_type == 'uniformer':
            # 'MHSA' only for channel-preserving blocks after the first one.
            block_type = 'MHSA' if in_channels == out_channels and layer_i > 0 else 'Conv'
            self.block = UniformerSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path,
                block_type=block_type)
        elif model_type == 'van':
            self.block = VANSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path,
                act_layer=nn.GELU)
        elif model_type == 'vit':
            self.block = ViTSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
        else:
            # Previously `assert False and "..."`: the message was discarded
            # and the check vanished under `python -O`.
            raise ValueError(f"Invalid model_type in SimVP: {model_type!r}")

        if in_channels != out_channels:
            self.reduction = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Apply the sub-block, then the 1x1 reduction if channels change."""
        z = self.block(x)
        return z if self.in_channels == self.out_channels else self.reduction(z)


class MidMetaNet(nn.Module):
    """Hidden translator built from ``N2`` :class:`MetaBlock` s.

    The first block maps ``channel_in -> channel_hid``, the middle ones keep
    ``channel_hid``, and the last maps back to ``channel_in``.

    Args:
        channel_in: channels of the flattened clip (``T * C`` of the input).
        channel_hid: hidden width between the first and last blocks.
        N2: number of blocks; must be >= 2.
        input_resolution, model_type, mlp_ratio, drop: forwarded to every block.
        drop_path: per-block drop-path rates rise linearly from 1e-2 to this
            value (stochastic depth decay rule).

    Raises:
        ValueError: if ``N2 < 2`` or ``mlp_ratio`` is not greater than 1.
    """

    def __init__(self, channel_in, channel_hid, N2, input_resolution=None,
                 model_type=None, mlp_ratio=4., drop=0.0, drop_path=0.1):
        super().__init__()
        # Explicit raises instead of `assert`, which `python -O` strips.
        if N2 < 2:
            raise ValueError(f"MidMetaNet requires N2 >= 2, got {N2}")
        if not mlp_ratio > 1:
            raise ValueError(f"MidMetaNet requires mlp_ratio > 1, got {mlp_ratio}")
        self.N2 = N2
        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(1e-2, drop_path, self.N2)]

        # downsample
        enc_layers = [MetaBlock(
            channel_in, channel_hid, input_resolution, model_type,
            mlp_ratio, drop, drop_path=dpr[0], layer_i=0)]
        # middle layers
        for i in range(1, N2 - 1):
            enc_layers.append(MetaBlock(
                channel_hid, channel_hid, input_resolution, model_type,
                mlp_ratio, drop, drop_path=dpr[i], layer_i=i))
        # upsample
        enc_layers.append(MetaBlock(
            channel_hid, channel_in, input_resolution, model_type,
            mlp_ratio, drop, drop_path=drop_path, layer_i=N2 - 1))
        self.enc = nn.Sequential(*enc_layers)

    def forward(self, x):
        """Translate a ``(B, T, C, H, W)`` tensor; returns the same shape."""
        B, T, C, H, W = x.shape
        z = x.reshape(B, T * C, H, W)
        for layer in self.enc:
            z = layer(z)
        return z.reshape(B, T, C, H, W)


def _latent_resolution(H, W, N_S):
    """Estimate the spatial size of :class:`Encoder`'s output for ``(H, W)`` input.

    NOTE(review): assumes each stride-2 ``ConvSC`` maps size ``s`` to
    ``ceil(s / 2)`` and stride-1 layers keep it -- TODO confirm against
    ``modules.ConvSC``.
    """
    for s in stride_generator(N_S):
        if s == 2:
            H, W = (H + 1) // 2, (W + 1) // 2
    return H, W


def _encode_translate_decode(model, x_raw):
    """Run ``model.enc`` -> ``model.hid`` -> ``model.dec`` on a clip.

    ``x_raw`` is ``(B, T, C, H, W)``. Frames are encoded and decoded
    independently (folded into the batch axis); the translator sees the
    whole ``(B, T, C_, H_, W_)`` latent clip. Returns ``(B, T, C, H, W)``.
    """
    B, T, C, H, W = x_raw.shape
    # reshape (not view) so non-contiguous inputs are accepted too.
    x = x_raw.reshape(B * T, C, H, W)
    embed, skip = model.enc(x)
    _, C_, H_, W_ = embed.shape
    hid = model.hid(embed.reshape(B, T, C_, H_, W_))
    hid = hid.reshape(B * T, C_, H_, W_)
    Y = model.dec(hid, skip)
    return Y.reshape(B, T, C, H, W)


class SimVP(nn.Module):
    """Encoder -> MetaFormer translator (:class:`MidMetaNet`) -> decoder.

    Args:
        hid_S: encoder/decoder channel width.
        hid_T: hidden width of the translator.
        N_S: number of encoder (and decoder) layers.
        N_T: number of translator blocks (must be >= 2).
        incep_ker, groups: accepted for signature parity with :class:`larres`;
            unused by this model.
        in_shape: ``(T, C, H, W)`` the model is built for. ``T`` and ``C`` fix
            layer widths, so inputs to :meth:`forward` must match them.
        model_type: translator sub-block type passed to :class:`MetaBlock`.
    """

    def __init__(self, hid_S=32, hid_T=256, N_S=2, N_T=8, incep_ker=None,
                 groups=4, in_shape=_DEFAULT_IN_SHAPE, model_type="vit"):
        super().__init__()
        T, C, H, W = in_shape
        self.enc = Encoder(C, hid_S, N_S)
        # The translator runs on the encoder's latent map, not the raw frames.
        self.hid = MidMetaNet(T * hid_S, hid_T, N_T,
                              input_resolution=_latent_resolution(H, W, N_S),
                              model_type=model_type, mlp_ratio=8,
                              drop=0.0, drop_path=0.1)
        self.dec = Decoder(hid_S, C, N_S)

    def forward(self, x_raw):
        """Predict a ``(B, T, C, H, W)`` clip from ``x_raw`` of the same shape."""
        return _encode_translate_decode(self, x_raw)


class larres(nn.Module):
    """Encoder -> Inception U-Net translator (:class:`Mid_Xnet`) -> decoder.

    Args:
        hid_S: encoder/decoder channel width.
        hid_T: hidden width of the translator.
        N_S: number of encoder (and decoder) layers.
        N_T: number of translator encoder blocks (must be >= 2).
        incep_ker: Inception kernel sizes; ``None`` means ``[3, 5, 7, 11]``.
        groups: Inception group count.
        in_shape: ``(T, C, H, W)`` the model is built for. ``T`` and ``C`` fix
            layer widths, so inputs to :meth:`forward` must match them.
    """

    def __init__(self, hid_S=32, hid_T=256, N_S=2, N_T=8, incep_ker=None,
                 groups=4, in_shape=_DEFAULT_IN_SHAPE):
        super().__init__()
        T, C, H, W = in_shape
        self.enc = Encoder(C, hid_S, N_S)
        self.hid = Mid_Xnet(T * hid_S, hid_T, N_T, incep_ker, groups)
        self.dec = Decoder(hid_S, C, N_S)

    def forward(self, x_raw):
        """Predict a ``(B, T, C, H, W)`` clip from ``x_raw`` of the same shape."""
        return _encode_translate_decode(self, x_raw)