# Provenance (text scraped from the hosting page): uploaded by "easylearning",
# commit a57e1d5 ("Upload 91 files", verified).
from torch import nn
import torch
from torch import nn
class BasicConv2d(nn.Module):
    """One 2-D convolution (regular or transposed) with optional GroupNorm + LeakyReLU.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: forwarded to the
            underlying ``nn.Conv2d`` / ``nn.ConvTranspose2d``.
        transpose: build an ``nn.ConvTranspose2d`` (with ``output_padding=stride // 2``)
            instead of an ``nn.Conv2d``.
        act_norm: when True, ``forward`` applies ``GroupNorm(2, out_channels)`` followed
            by ``LeakyReLU(0.2)`` to the convolution output.

    The norm and activation modules are always created (so they always appear in the
    state dict), even when ``act_norm`` is False and they are never called.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, transpose=False, act_norm=False):
        super(BasicConv2d, self).__init__()
        self.act_norm = act_norm
        shared = dict(kernel_size=kernel_size, stride=stride, padding=padding)
        if transpose:
            # output_padding = stride // 2 makes a stride-2, kernel-3, padding-1
            # transposed conv produce exactly twice the input height/width.
            self.conv = nn.ConvTranspose2d(in_channels, out_channels, output_padding=stride // 2, **shared)
        else:
            self.conv = nn.Conv2d(in_channels, out_channels, **shared)
        self.norm = nn.GroupNorm(2, out_channels)
        self.act = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        out = self.conv(x)
        return self.act(self.norm(out)) if self.act_norm else out
class ConvSC(nn.Module):
    """3x3 convolution unit (padding 1) built on ``BasicConv2d``.

    ``transpose`` only takes effect when ``stride`` is not 1; a stride-1 unit is
    always a regular convolution. ``act_norm`` defaults to True here, so the unit
    applies GroupNorm + LeakyReLU after the convolution unless told otherwise.
    """

    def __init__(self, C_in, C_out, stride, transpose=False, act_norm=True):
        super(ConvSC, self).__init__()
        use_transpose = False if stride == 1 else transpose
        self.conv = BasicConv2d(
            C_in,
            C_out,
            kernel_size=3,
            stride=stride,
            padding=1,
            transpose=use_transpose,
            act_norm=act_norm,
        )

    def forward(self, x):
        return self.conv(x)
class GroupConv2d(nn.Module):
    """Grouped 2-D convolution with optional GroupNorm + LeakyReLU.

    Args:
        in_channels, out_channels, kernel_size, stride, padding: forwarded to ``nn.Conv2d``.
        groups: requested group count for both the convolution and the GroupNorm.
            If either channel count is not divisible by it, the layer falls back to
            ``groups=1`` (an ungrouped conv and a single-group norm) instead of failing.
        act_norm: when True, ``forward`` applies ``GroupNorm`` then ``LeakyReLU(0.2)``.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, groups, act_norm=False):
        super(GroupConv2d, self).__init__()
        self.act_norm = act_norm
        # nn.Conv2d needs BOTH in_channels and out_channels divisible by groups, and
        # nn.GroupNorm needs out_channels divisible by its group count. Checking only
        # in_channels (as before) still crashed for e.g. 8 -> 12 channels with groups=8.
        if in_channels % groups != 0 or out_channels % groups != 0:
            groups = 1
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, groups=groups)
        self.norm = nn.GroupNorm(groups, out_channels)
        self.activate = nn.LeakyReLU(0.2, inplace=True)

    def forward(self, x):
        y = self.conv(x)
        if self.act_norm:
            y = self.activate(self.norm(y))
        return y
class Inception(nn.Module):
    """1x1 bottleneck followed by parallel grouped-conv branches whose outputs are summed.

    Args:
        C_in: input channels.
        C_hid: channels after the 1x1 bottleneck convolution.
        C_out: channels produced by every branch (and therefore by the block).
        incep_ker: kernel sizes, one ``GroupConv2d`` branch per entry. Padding is
            ``ker // 2``, which keeps height/width unchanged for odd kernel sizes.
            Must be non-empty.
        groups: group count requested for each branch (GroupConv2d may fall back to 1).

    Raises:
        ValueError: if ``incep_ker`` is empty (the block would otherwise return the
            int ``0`` instead of a tensor).
    """

    def __init__(self, C_in, C_hid, C_out, incep_ker=(3, 5, 7, 11), groups=8):
        super(Inception, self).__init__()
        # A tuple default avoids sharing one mutable list across all instances.
        incep_ker = tuple(incep_ker)
        if not incep_ker:
            raise ValueError("incep_ker must contain at least one kernel size")
        self.conv1 = nn.Conv2d(C_in, C_hid, kernel_size=1, stride=1, padding=0)
        layers = []
        for ker in incep_ker:
            layers.append(GroupConv2d(C_hid, C_out, kernel_size=ker, stride=1, padding=ker // 2, groups=groups, act_norm=True))
        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        # Start from int 0: the first `+=` then yields a fresh tensor (0 + branch), so
        # later in-place additions never write into a branch's own output, which comes
        # from an in-place LeakyReLU inside GroupConv2d.
        y = 0
        for layer in self.layers:
            y += layer(x)
        return y
def stride_generator(N, reverse=False):
    """Return the alternating stride schedule ``[1, 2, 1, 2, ...]`` of length ``N``.

    Args:
        N: number of layers. Any ``N`` is supported; ``N <= 0`` yields ``[]``.
            The previous implementation silently capped the schedule at 20 entries
            and returned a non-empty list for negative ``N``.
        reverse: return the schedule in reverse order (used for the decoder).

    Returns:
        A new list of ints.

    >>> stride_generator(4)
    [1, 2, 1, 2]
    >>> stride_generator(3, reverse=True)
    [1, 2, 1]
    """
    strides = [1 if i % 2 == 0 else 2 for i in range(N)]
    return strides[::-1] if reverse else strides
class Encoder(nn.Module):
    """Stack of ``ConvSC`` units following the stride schedule ``stride_generator(N_S)``.

    The first unit maps ``C_in -> C_hid``; the remaining ``N_S - 1`` units keep
    ``C_hid`` channels. ``forward`` returns ``(latent, enc1)``, where ``enc1`` is the
    output of the first unit. ``SimVP.forward`` hands it to the Decoder as a skip.
    """

    def __init__(self, C_in, C_hid, N_S):
        super(Encoder, self).__init__()
        strides = stride_generator(N_S)
        first_stride, later_strides = strides[0], strides[1:]
        units = [ConvSC(C_in, C_hid, stride=first_stride)]
        units.extend(ConvSC(C_hid, C_hid, stride=s) for s in later_strides)
        self.enc = nn.Sequential(*units)

    def forward(self, x):
        # SimVP.forward calls this with frames folded into the batch, (B*T, C, H, W);
        # the original inline note read "B*4, 3, 128, 128".
        units = iter(self.enc)
        enc1 = next(units)(x)
        latent = enc1
        for unit in units:
            latent = unit(latent)
        return latent, enc1
class Decoder(nn.Module):
    """Mirror of ``Encoder``: ConvSC units on the reversed stride schedule, then a 1x1 readout.

    Every unit is requested as transposed; ``ConvSC`` ignores that for stride 1, so
    only the stride-2 units upsample. The final unit consumes the running features
    concatenated with ``enc1`` along channels (hence ``2 * C_hid`` input channels).
    A 1x1 ``nn.Conv2d`` then maps ``C_hid -> C_out``.
    """

    def __init__(self, C_hid, C_out, N_S):
        super(Decoder, self).__init__()
        strides = stride_generator(N_S, reverse=True)
        units = [ConvSC(C_hid, C_hid, stride=s, transpose=True) for s in strides[:-1]]
        units.append(ConvSC(2 * C_hid, C_hid, stride=strides[-1], transpose=True))
        self.dec = nn.Sequential(*units)
        self.readout = nn.Conv2d(C_hid, C_out, 1)

    def forward(self, hid, enc1=None):
        # NOTE(review): enc1 defaults to None, but torch.cat below needs a tensor, so
        # callers must always pass the encoder's first-unit features.
        *body, fuse = self.dec
        for unit in body:
            hid = unit(hid)
        Y = fuse(torch.cat([hid, enc1], dim=1))
        return self.readout(Y)
class Mid_Xnet(nn.Module):
    """U-Net-like stack of Inception blocks applied to frames folded into channels.

    ``forward`` reshapes ``(B, T, C, H, W)`` to ``(B, T*C, H, W)``. It then runs
    ``N_T`` encoder blocks and ``N_T`` decoder blocks and reshapes back to
    ``(B, T, C, H, W)``. Every decoder block after the first receives the previous
    decoder output concatenated (along channels) with an encoder output. The encoder
    outputs are used deepest-first, excluding the deepest one.

    Args:
        channel_in: channels of the folded input. Must equal ``T * C``, because the
            last decoder block emits ``channel_in`` channels that are reshaped back.
        channel_hid: width of the hidden Inception blocks (bottleneck is half of it).
        N_T: number of encoder and decoder blocks used by ``forward``.
            NOTE(review): with ``N_T < 2`` some blocks built here are never used and
            the output has ``channel_hid`` channels, so ``N_T >= 2`` is presumably
            required. Confirm.
        incep_ker: kernel sizes of the parallel branches in each Inception block.
        groups: group count forwarded to the Inception branches.
    """

    def __init__(self, channel_in, channel_hid, N_T, incep_ker=[3, 5, 7, 11], groups=8):
        super(Mid_Xnet, self).__init__()
        self.N_T = N_T
        bottleneck = channel_hid // 2

        def make(c_in, c_out):
            return Inception(c_in, bottleneck, c_out, incep_ker=incep_ker, groups=groups)

        # Build all encoder blocks first, then the decoder blocks (construction order
        # fixes the random parameter initialisation sequence).
        encoders = [make(channel_in, channel_hid)]
        encoders += [make(channel_hid, channel_hid) for _ in range(1, N_T - 1)]
        encoders.append(make(channel_hid, channel_hid))

        decoders = [make(channel_hid, channel_hid)]
        decoders += [make(2 * channel_hid, channel_hid) for _ in range(1, N_T - 1)]
        decoders.append(make(2 * channel_hid, channel_in))

        self.enc = nn.Sequential(*encoders)
        self.dec = nn.Sequential(*decoders)

    def forward(self, x):
        batch, frames, chans, height, width = x.shape
        z = x.reshape(batch, frames * chans, height, width)

        # Encoder: keep every stage output; all but the deepest become skips.
        stage_outputs = []
        for i in range(self.N_T):
            z = self.enc[i](z)
            stage_outputs.append(z)
        skips = stage_outputs[:-1]

        # Decoder: the first stage sees only the bottleneck; later stages also get
        # the skip from the mirrored encoder stage, consumed deepest-first.
        z = self.dec[0](z)
        for i in range(1, self.N_T):
            z = self.dec[i](torch.cat([z, skips.pop()], dim=1))

        return z.reshape(batch, frames, chans, height, width)
class SimVP(nn.Module):
    """Per-frame spatial encoder -> frame-stacked Inception translator -> per-frame decoder.

    Args:
        shape_in: ``(T, C, H, W)`` of one input sample.
        hid_S: channel width of the spatial Encoder/Decoder.
        hid_T: channel width of the ``Mid_Xnet`` translator.
        N_S: number of ConvSC units in the Encoder (and Decoder).
        N_T: number of Inception blocks in each half of ``Mid_Xnet``.
        output_dim: channels per output frame (the Decoder's readout width).
        incep_ker: kernel sizes for the Inception branches (a tuple default avoids a
            shared mutable list; any iterable of ints works).
        groups: group count for the Inception branches.

    ``forward`` maps ``(B, T, C, H, W)`` to ``(B, T, output_dim, H, W)``.
    NOTE(review): H and W presumably must be divisible by ``2 ** (N_S // 2)`` (the
    number of stride-2 units) so the decoder's skip concatenation lines up. Confirm.
    """

    def __init__(self, shape_in, hid_S=16, hid_T=256, N_S=4, N_T=8, output_dim=1, incep_ker=(3, 5, 7, 11), groups=8):
        super(SimVP, self).__init__()
        T, C, H, W = shape_in
        self.output_dim = output_dim
        self.enc = Encoder(C, hid_S, N_S)
        self.hid = Mid_Xnet(T * hid_S, hid_T, N_T, incep_ker, groups)
        self.dec = Decoder(hid_S, self.output_dim, N_S)

    def forward(self, x_raw):
        B, T, C, H, W = x_raw.shape
        # reshape (not view): view raises on non-contiguous inputs such as permuted
        # or channel-sliced tensors, while reshape copies only when it must.
        x = x_raw.reshape(B * T, C, H, W)
        embed, skip = self.enc(x)
        _, C_, H_, W_ = embed.shape
        z = embed.view(B, T, C_, H_, W_)
        hid = self.hid(z)
        hid = hid.reshape(B * T, C_, H_, W_)
        Y = self.dec(hid, skip)
        Y = Y.reshape(B, T, -1, H, W)
        return Y
if __name__ == "__main__":
    # Smoke test: push one random batch through the model and print the output shape.
    inputs = torch.randn(1, 10, 2, 64, 448)
    model = SimVP(shape_in=(10, 2, 64, 448), hid_S=32, hid_T=128, output_dim=2)
    # Only the shape is inspected, so skip autograd bookkeeping to save memory.
    with torch.no_grad():
        outputs = model(inputs)
    print(outputs.shape)