# LARRES / model_LARRES.py
# Uploaded by Staty ("Upload 50 files", commit 2b21abc, verified).
import torch
from torch import nn
from modules import ConvSC, Inception
from utilpack import (ConvNeXtSubBlock, ConvMixerSubBlock, GASubBlock, gInception_ST,
HorNetSubBlock, MLPMixerSubBlock, MogaSubBlock, PoolFormerSubBlock,
SwinSubBlock, UniformerSubBlock, VANSubBlock, ViTSubBlock, TAUSubBlock)
def stride_generator(N, reverse=False):
    """Return the per-layer stride pattern ``[1, 2, 1, 2, ...]`` of length ``N``.

    ``Encoder`` uses this pattern so every other ConvSC layer has stride 2;
    ``Decoder`` requests the reversed pattern for its transposed layers.

    Args:
        N: Number of strides to generate. Values <= 0 yield ``[]``.
        reverse: If True, return the pattern in reverse order.

    Returns:
        list[int]: Strides alternating 1 and 2. The list starts with 1, or
        ends with 1 when ``reverse`` is True.
    """
    # Generate exactly N entries. A fixed ``[1, 2] * 10`` table would
    # silently cap the result at 20 strides for N > 20.
    strides = [2 if i % 2 else 1 for i in range(N)]
    return strides[::-1] if reverse else strides
class Encoder(nn.Module):
    """Spatial encoder: a stack of ConvSC layers applied to a batch of frames.

    Args:
        C_in: Channels of each input frame.
        C_hid: Channels produced by every ConvSC layer.
        N_S: Number of ConvSC layers. Their strides follow
            ``stride_generator(N_S)``, i.e. 1, 2, 1, 2, ...
    """

    def __init__(self, C_in, C_hid, N_S):
        super().__init__()
        layer_strides = stride_generator(N_S)
        # Only the first layer changes the channel count (C_in -> C_hid).
        layers = [ConvSC(C_in, C_hid, stride=layer_strides[0])]
        layers.extend(
            ConvSC(C_hid, C_hid, stride=stride) for stride in layer_strides[1:]
        )
        self.enc = nn.Sequential(*layers)

    def forward(self, x):
        """Encode ``x`` and return ``(latent, enc1)``.

        ``enc1`` is the output of the first layer and is returned separately
        so the decoder can use it as a skip connection. ``latent`` is the
        output of the whole stack.
        """
        # NOTE(review): in this file SimVP/larres pass frames shaped
        # (B*T, C, H, W); confirm if other callers exist.
        enc1 = self.enc[0](x)
        # Slicing a Sequential yields a Sequential of the remaining layers;
        # an empty slice returns its input unchanged.
        latent = self.enc[1:](enc1)
        return latent, enc1
class Decoder(nn.Module):
    """Spatial decoder with a skip connection into its last layer.

    Args:
        C_hid: Channels of the latent features and of every ConvSC layer.
        C_out: Channels of the output produced by the 1x1 readout.
        N_S: Number of transposed ConvSC layers. Their strides follow
            ``stride_generator(N_S, reverse=True)``.
    """

    def __init__(self, C_hid, C_out, N_S):
        super().__init__()
        strides = stride_generator(N_S, reverse=True)
        self.dec = nn.Sequential(
            *[ConvSC(C_hid, C_hid, stride=s, transpose=True) for s in strides[:-1]],
            # The last layer receives hid concatenated with the encoder skip
            # along dim 1, hence 2 * C_hid input channels.
            ConvSC(2 * C_hid, C_hid, stride=strides[-1], transpose=True),
        )
        # 1x1 convolution mapping C_hid channels to C_out.
        self.readout = nn.Conv2d(C_hid, C_out, 1)

    def forward(self, hid, enc1=None):
        """Decode ``hid`` using the encoder skip features ``enc1``.

        Args:
            hid: Latent features fed to the first decoder layer.
            enc1: Skip features to concatenate (dim 1) before the last layer,
                typically ``Encoder``'s first-layer output. Required: the
                last layer expects ``2 * C_hid`` channels. The ``None``
                default is kept only for signature compatibility.

        Returns:
            The readout of the last decoder layer.

        Raises:
            TypeError: If ``enc1`` is None.
        """
        if enc1 is None:
            raise TypeError(
                "Decoder.forward requires the encoder skip tensor 'enc1'; got None"
            )
        for layer in self.dec[:-1]:
            hid = layer(hid)
        Y = self.dec[-1](torch.cat([hid, enc1], dim=1))
        return self.readout(Y)
class Mid_Xnet(nn.Module):
    """Temporal translator: an encoder/decoder stack of Inception blocks.

    Time is folded into channels (``T*C``). The input passes through ``N_T``
    encoder blocks and then ``N_T`` decoder blocks. Every decoder block after
    the first also receives a stored encoder output (skip connection), with
    skips consumed in reverse order.

    Args:
        channel_in: Channels after folding time into channels (``T*C``).
        channel_hid: Output channels of the hidden Inception blocks.
        N_T: Number of encoder blocks and of decoder blocks.
        incep_ker: Kernel sizes passed to every Inception block. ``None``
            (default) means ``[3, 5, 7, 11]``.
        groups: Group count passed to every Inception block.
    """

    def __init__(self, channel_in, channel_hid, N_T, incep_ker=None, groups=8):
        super().__init__()
        # Fresh list per instance instead of a shared mutable default.
        if incep_ker is None:
            incep_ker = [3, 5, 7, 11]
        self.N_T = N_T

        def inception(c_in, c_out):
            # Every block uses channel_hid // 2 as its middle width.
            return Inception(c_in, channel_hid // 2, c_out,
                             incep_ker=incep_ker, groups=groups)

        # Blocks are created in the same order as before (encoder first, then
        # decoder), so seeded parameter initialisation is unchanged.
        enc_layers = [inception(channel_in, channel_hid)]
        enc_layers += [inception(channel_hid, channel_hid) for _ in range(1, N_T - 1)]
        enc_layers.append(inception(channel_hid, channel_hid))

        dec_layers = [inception(channel_hid, channel_hid)]
        # Decoder blocks after the first see [z, skip] concatenated -> 2x channels.
        dec_layers += [inception(2 * channel_hid, channel_hid) for _ in range(1, N_T - 1)]
        dec_layers.append(inception(2 * channel_hid, channel_in))

        self.enc = nn.Sequential(*enc_layers)
        self.dec = nn.Sequential(*dec_layers)

    def forward(self, x):
        """Translate ``x`` of shape (B, T, C, H, W); returns the same shape.

        NOTE(review): with ``N_T < 2`` only ``enc[0]``/``dec[0]`` run and
        ``dec[0]`` outputs ``channel_hid`` channels. The final reshape then
        only works when ``channel_hid == channel_in``.
        """
        B, T, C, H, W = x.shape
        z = x.reshape(B, T * C, H, W)

        # Encoder: keep every output except the last as a skip connection.
        skips = []
        for i in range(self.N_T):
            z = self.enc[i](z)
            if i < self.N_T - 1:
                skips.append(z)

        # Decoder: block i (i >= 1) consumes skips[-i], i.e. reverse order.
        z = self.dec[0](z)
        for i in range(1, self.N_T):
            z = self.dec[i](torch.cat([z, skips[-i]], dim=1))
        return z.reshape(B, T, C, H, W)
class MetaBlock(nn.Module):
    """The hidden Translator of MetaFormer for SimVP.

    Wraps one MetaFormer-style sub-block selected by ``model_type``. When
    ``in_channels != out_channels``, a 1x1 convolution after the sub-block
    maps the result to ``out_channels``.

    Args:
        in_channels: Channels of the input feature map and of the sub-block.
        out_channels: Channels of the output feature map.
        input_resolution: Spatial resolution. Only the ``'mlp'``/``'mlpmixer'``
            and ``'swin'`` sub-blocks receive it.
        model_type: Case-insensitive sub-block name; ``None`` selects
            ``'gsta'``. Supported: gsta, convmixer, convnext, hornet,
            mlp/mlpmixer, moga/moganet, poolformer, swin, uniformer, van, vit.
        mlp_ratio: MLP expansion ratio (not passed to ``'convmixer'``).
        drop: Dropout rate (not passed to ``'convmixer'`` or ``'hornet'``).
        drop_path: Drop-path rate (not passed to ``'convmixer'``).
        layer_i: Index of this block in its stack. Passed to ``'swin'`` and
            used to choose the ``'uniformer'`` block type.

    Raises:
        ValueError: If ``model_type`` is not a supported name.
    """

    def __init__(self, in_channels, out_channels, input_resolution=None, model_type=None,
                 mlp_ratio=8., drop=0.0, drop_path=0.0, layer_i=0):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        model_type = model_type.lower() if model_type is not None else 'gsta'

        if model_type == 'gsta':
            self.block = GASubBlock(
                in_channels, kernel_size=21, mlp_ratio=mlp_ratio,
                drop=drop, drop_path=drop_path, act_layer=nn.GELU)
        elif model_type == 'convmixer':
            self.block = ConvMixerSubBlock(in_channels, kernel_size=11, activation=nn.GELU)
        elif model_type == 'convnext':
            self.block = ConvNeXtSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
        elif model_type == 'hornet':
            self.block = HorNetSubBlock(in_channels, mlp_ratio=mlp_ratio, drop_path=drop_path)
        elif model_type in ['mlp', 'mlpmixer']:
            self.block = MLPMixerSubBlock(
                in_channels, input_resolution, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
        elif model_type in ['moga', 'moganet']:
            self.block = MogaSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop_rate=drop, drop_path_rate=drop_path)
        elif model_type == 'poolformer':
            self.block = PoolFormerSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
        elif model_type == 'swin':
            self.block = SwinSubBlock(
                in_channels, input_resolution, layer_i=layer_i, mlp_ratio=mlp_ratio,
                drop=drop, drop_path=drop_path)
        elif model_type == 'uniformer':
            # Attention ('MHSA') only for channel-preserving blocks after the
            # first one; otherwise a 'Conv' block.
            block_type = 'MHSA' if in_channels == out_channels and layer_i > 0 else 'Conv'
            self.block = UniformerSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop,
                drop_path=drop_path, block_type=block_type)
        elif model_type == 'van':
            self.block = VANSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path, act_layer=nn.GELU)
        elif model_type == 'vit':
            self.block = ViTSubBlock(
                in_channels, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)
        else:
            # Raise explicitly. An ``assert`` is stripped under ``python -O``
            # (leaving ``self.block`` unset), and ``assert False and msg``
            # never shows ``msg``.
            raise ValueError(f"Invalid model_type in SimVP: {model_type!r}")

        if in_channels != out_channels:
            # 1x1 convolution mapping in_channels -> out_channels.
            self.reduction = nn.Conv2d(
                in_channels, out_channels, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Apply the sub-block, then the 1x1 channel projection if one exists."""
        z = self.block(x)
        return z if self.in_channels == self.out_channels else self.reduction(z)
class MidMetaNet(nn.Module):
    """The hidden Translator of MetaFormer for SimVP.

    Folds time into channels and applies ``N2`` MetaBlocks in sequence:
    ``channel_in -> channel_hid``, then ``N2 - 2`` blocks
    ``channel_hid -> channel_hid``, then ``channel_hid -> channel_in``.

    Args:
        channel_in: Channels after folding time into channels (``T*C``).
        channel_hid: Hidden channels of the inner blocks.
        N2: Number of MetaBlocks; must be >= 2.
        input_resolution: Forwarded to every MetaBlock.
        model_type: Forwarded to every MetaBlock.
        mlp_ratio: MLP expansion ratio; must be > 1.
        drop: Dropout rate for every block.
        drop_path: Drop-path rate of the last block. Earlier blocks use
            rates linearly spaced from 1e-2 up to this value.

    Raises:
        ValueError: If ``N2 < 2`` or ``mlp_ratio <= 1``.
    """

    def __init__(self, channel_in, channel_hid, N2,
                 input_resolution=None, model_type=None,
                 mlp_ratio=4., drop=0.0, drop_path=0.1):
        super().__init__()
        # Explicit checks rather than ``assert`` so they survive ``python -O``.
        if N2 < 2:
            raise ValueError(f"MidMetaNet needs N2 >= 2 blocks, got {N2}")
        if not mlp_ratio > 1:
            raise ValueError(f"MidMetaNet needs mlp_ratio > 1, got {mlp_ratio}")
        self.N2 = N2
        # Per-block drop-path rates, linearly spaced from 1e-2 to drop_path.
        dpr = [x.item() for x in torch.linspace(1e-2, drop_path, self.N2)]

        # First block: channel_in -> channel_hid.
        enc_layers = [MetaBlock(
            channel_in, channel_hid, input_resolution, model_type,
            mlp_ratio, drop, drop_path=dpr[0], layer_i=0)]
        # Middle blocks: channel_hid -> channel_hid.
        for i in range(1, N2 - 1):
            enc_layers.append(MetaBlock(
                channel_hid, channel_hid, input_resolution, model_type,
                mlp_ratio, drop, drop_path=dpr[i], layer_i=i))
        # Last block: channel_hid -> channel_in.
        enc_layers.append(MetaBlock(
            channel_hid, channel_in, input_resolution, model_type,
            mlp_ratio, drop, drop_path=drop_path, layer_i=N2 - 1))
        self.enc = nn.Sequential(*enc_layers)

    def forward(self, x):
        """Translate ``x`` of shape (B, T, C, H, W); returns the same shape."""
        B, T, C, H, W = x.shape
        z = x.reshape(B, T * C, H, W)
        for i in range(self.N2):
            z = self.enc[i](z)
        return z.reshape(B, T, C, H, W)
class SimVP(nn.Module):
    """SimVP video-prediction model with a MetaFormer ("vit") translator.

    All ``B*T`` frames go through ``Encoder`` as one batch. The per-frame
    latents are translated by ``MidMetaNet`` (time folded into channels), and
    ``Decoder`` maps them back to frames using the encoder's skip output.

    Args:
        hid_S: Channels of the spatial encoder/decoder.
        hid_T: Hidden channels of the translator.
        N_S: Number of ConvSC layers in the encoder and in the decoder.
        N_T: Number of MetaBlocks in the translator.
        incep_ker: Unused; accepted for signature compatibility.
        groups: Unused; accepted for signature compatibility.
        in_shape: ``(T, C, H, W)`` of the input sequences. The default
            ``(36, 1, 72, 72)`` is the previously hard-coded value. The
            runtime ``T`` must match, because the translator's channel count
            is ``T * hid_S``.
    """

    def __init__(self, hid_S=32, hid_T=256, N_S=2, N_T=8, incep_ker=None, groups=4,
                 in_shape=(36, 1, 72, 72)):
        super().__init__()
        T, C, H, W = in_shape
        self.enc = Encoder(C, hid_S, N_S)
        # NOTE(review): input_resolution is the raw frame size, not the
        # encoder's latent size. MetaBlock only forwards it to the 'swin' and
        # 'mlp'/'mlpmixer' sub-blocks, so it is unused with model_type="vit".
        # Revisit before changing model_type.
        self.hid = MidMetaNet(T * hid_S, hid_T, N_T,
                              input_resolution=(H, W), model_type="vit",
                              mlp_ratio=8, drop=0.0, drop_path=0.1)
        self.dec = Decoder(hid_S, C, N_S)

    def forward(self, x_raw):
        """Predict frames for ``x_raw`` of shape (B, T, C, H, W); same shape out."""
        B, T, C, H, W = x_raw.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x_raw.reshape(B * T, C, H, W)
        embed, skip = self.enc(x)
        _, C_, H_, W_ = embed.shape
        z = embed.reshape(B, T, C_, H_, W_)
        hid = self.hid(z)
        hid = hid.reshape(B * T, C_, H_, W_)
        Y = self.dec(hid, skip)
        return Y.reshape(B, T, C, H, W)
class larres(nn.Module):
    """LARRES video-prediction model with an Inception-based translator.

    All ``B*T`` frames go through ``Encoder`` as one batch. The per-frame
    latents are translated by ``Mid_Xnet`` (time folded into channels), and
    ``Decoder`` maps them back to frames using the encoder's skip output.

    Args:
        hid_S: Channels of the spatial encoder/decoder.
        hid_T: Hidden channels of the translator.
        N_S: Number of ConvSC layers in the encoder and in the decoder.
        N_T: Number of encoder/decoder blocks in the translator.
        incep_ker: Kernel sizes for the translator's Inception blocks.
            ``None`` (default) means ``[3, 5, 7, 11]``.
        groups: Group count for the translator's Inception blocks.
        in_shape: ``(T, C, H, W)`` of the input sequences. The default
            ``(36, 1, 72, 72)`` is the previously hard-coded value. The
            runtime ``T`` must match, because the translator's channel count
            is ``T * hid_S``.
    """

    def __init__(self, hid_S=32, hid_T=256, N_S=2, N_T=8, incep_ker=None, groups=4,
                 in_shape=(36, 1, 72, 72)):
        super().__init__()
        # Fresh list per instance instead of a shared mutable default.
        if incep_ker is None:
            incep_ker = [3, 5, 7, 11]
        T, C, H, W = in_shape
        self.enc = Encoder(C, hid_S, N_S)
        self.hid = Mid_Xnet(T * hid_S, hid_T, N_T, incep_ker, groups)
        self.dec = Decoder(hid_S, C, N_S)

    def forward(self, x_raw):
        """Predict frames for ``x_raw`` of shape (B, T, C, H, W); same shape out."""
        B, T, C, H, W = x_raw.shape
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = x_raw.reshape(B * T, C, H, W)
        embed, skip = self.enc(x)
        _, C_, H_, W_ = embed.shape
        z = embed.reshape(B, T, C_, H_, W_)
        hid = self.hid(z)
        hid = hid.reshape(B * T, C_, H_, W_)
        Y = self.dec(hid, skip)
        return Y.reshape(B, T, C, H, W)