# Source snapshot: commit 6021dd1 (4,430 bytes)
import torch
from torch import nn
from stldm.submodules import ChannelConversion
from stldm.simvpv2 import stride_generator, ConvSC, MidMetaNet
class Encoder(nn.Module):
    """Spatial encoder: a stack of ConvSC stages followed by a channel
    expansion so the output can be split into Gaussian (mean, log_var)."""

    def __init__(self, C_in, C_hid, N_S):
        super().__init__()
        strides = stride_generator(N_S)
        stages = [ConvSC(C_in, C_hid, stride=strides[0])]
        stages.extend(ConvSC(C_hid, C_hid, stride=s) for s in strides[1:])
        # Double the channel count: first half is the mean, second the log-var.
        stages.append(ChannelConversion(C_hid, 2 * C_hid))
        self.enc = nn.Sequential(*stages)

    def forward(self, x):
        """Apply every stage in order, then split channels into (mean, log_var)."""
        out = x
        for stage in self.enc:
            out = stage(out)
        mean, log_var = torch.chunk(out, 2, dim=1)
        return mean, log_var
class Decoder(nn.Module):
    """Spatial decoder mirroring Encoder: transposed ConvSC stages, a 1x1
    readout convolution, and an optional sigmoid output activation."""

    def __init__(self, C_hid, C_out, N_S, last_activation='sigmoid'):
        super().__init__()
        strides = stride_generator(N_S, reverse=True)
        stages = [ChannelConversion(C_hid, C_hid)]
        stages.extend(
            ConvSC(C_hid, C_hid, stride=s, transpose=True) for s in strides[:-1]
        )
        # Final upsampling stage (tagged "Modify HERE" in the original).
        stages.append(ConvSC(C_hid, C_hid, stride=strides[-1], transpose=True))
        self.dec = nn.Sequential(*stages)
        self.readout = nn.Conv2d(C_hid, C_out, 1)
        self.last = nn.Sigmoid() if last_activation == 'sigmoid' else nn.Identity()

    def forward(self, x):
        """Run all decoder stages, project to C_out channels, apply activation."""
        out = x
        for stage in self.dec:
            out = stage(out)
        return self.last(self.readout(out))
class VAE(nn.Module):
    """Convolutional VAE built from the Encoder/Decoder pair above.

    ``encode`` produces the mean and log-variance of a diagonal Gaussian
    posterior over a 4-D latent; ``decode`` maps a latent batch back to
    image space; ``forward`` samples via the reparameterization trick.
    """

    def __init__(self, C_in, hid_S, N_S, last_activation='none'):
        super(VAE, self).__init__()
        self.encoder = Encoder(C_in, hid_S, N_S)
        self.decoder = Decoder(hid_S, C_in, N_S, last_activation)

    def sample_from_standard_normal(self, mean, log_var):
        """Reparameterization trick: z = mean + std * eps, eps ~ N(0, I)."""
        std = (0.5 * log_var).exp()
        return mean + std * torch.randn_like(mean)

    def encode(self, x):
        """Encode a 4-D (N, C, H, W) batch into posterior (mean, log_var)."""
        assert x.ndim == 4
        mean, log_var = self.encoder(x)
        return mean, log_var

    def decode(self, z):
        """Decode a 4-D latent batch back to image space."""
        assert z.ndim == 4
        return self.decoder(z)

    def kl_from_standard_normal(self, mean, log_var):
        """Mean elementwise KL divergence of N(mean, exp(log_var)) from N(0, I)."""
        kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var)
        return kl.mean()

    def _losses_(self, x, y):
        """Return (reconstruction MSE against y, KL loss).

        Fix: encode x once and reuse the posterior for both the KL term and
        the reconstruction. The original called self.forward(x) after
        encode(x), running the encoder twice per loss evaluation.
        """
        mean, log_var = self.encode(x)
        kl_loss = self.kl_from_standard_normal(mean, log_var)
        z = self.sample_from_standard_normal(mean, log_var)
        y_pred = self.decode(z)
        recon_loss = nn.MSELoss()(y_pred, y)
        return recon_loss, kl_loss

    def forward(self, x):
        """Full pass: encode, sample a latent, decode the reconstruction."""
        mu_z, log_var = self.encode(x)
        z = self.sample_from_standard_normal(mu_z, log_var)
        return self.decode(z)
class SimVPV2_Model(nn.Module):
    """SimVPv2-style video prediction model with a VAE spatial codec.

    A shared VAE encodes each input frame to a latent; a MidMetaNet
    temporal translator predicts (mean, log-var) for the T2 future
    latents, which are sampled and decoded back to frames.
    """

    def __init__(self, shape_in, shape_out, hid_S=16, hid_T=256, N_S=4, N_T=4,
                 mlp_ratio=8., drop=0.0, drop_path=0.0, spatio_kernel_enc=3,
                 spatio_kernel_dec=3, last_activation='none', act_inplace=True, **kwargs):
        super().__init__()
        T, C, H, W = shape_in        # T: input (pre-seq) length
        T2, C2, H2, W2 = shape_out   # T2: output sequence length
        assert C==C2 and H==H2 and W==W2, 'Need to be the same image shape for input and output'
        self.T2 = T2
        self.T = T
        # Latent resolution after the encoder's downsampling by 1 / 2**(N_S/2).
        H, W = int(H / 2 ** (N_S / 2)), int(W / 2 ** (N_S / 2))
        self.vae = VAE(C_in=C, hid_S=hid_S, N_S=N_S, last_activation=last_activation)
        # Translator emits 2x channels: predicted mean and log-var of future latents.
        self.hid = MidMetaNet(T * hid_S, T2 * hid_S * 2, hid_T, N_T,
                              input_resolution=(H, W), model_type='gsta',
                              mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)

    def forward(self, x_raw):
        """Predict T2 future frames; returns (frames, predicted latent means)."""
        B, T, C, H, W = x_raw.shape
        frames = x_raw.reshape(B * T, C, H, W)
        # Per-frame posterior; sample one latent per input frame.
        mu, log_var = self.vae.encode(frames)
        latent = self.vae.sample_from_standard_normal(mu, log_var)
        *_, C_, H_, W_ = latent.shape
        z = latent.view(B, T, C_, H_, W_)
        # NOTE(review): assumes self.hid(z) yields an iterable whose first
        # element is the prediction — confirm against MidMetaNet's return.
        hid, *_ = self.hid(z)
        # Split dim 1 into predicted mean / log-var, then sample future latents.
        hid_mu, hid_log_var = torch.chunk(hid, 2, dim=1)
        sampled = self.vae.sample_from_standard_normal(hid_mu, hid_log_var)
        sampled = sampled.reshape(B * self.T2, C_, H_, W_)
        conds_ = hid_mu.reshape(B * self.T2, C_, H_, W_)
        Y = self.vae.decode(sampled).reshape(B, self.T2, C, H, W)
        return Y, conds_

    def _losses_(self, x, y):
        """MSE between the predicted frame sequence and the target y."""
        y_pred, *_ = self.forward(x)
        return nn.MSELoss()(y_pred, y)