# STLDM_official / stldm/stldm_spatial.py
# Uploaded by sqfoo ("Upload 99 files"), commit 6021dd1 (verified).
import torch, random
from torch import nn
from einops import rearrange
from stldm.submodules import *
class Down_Block(nn.Module):
    """Encoder stage of the spatial latent U-Net, applied frame by frame.

    The (B, T, C, H, W) input is folded into a batch of B*T images and
    concatenated channel-wise with a condition map. The result then goes through
    ResnetBlock -> spatial attention -> ResnetBlock (-> no-op). It is finally
    downsampled (intermediate levels) or channel-projected (last level).

    Args:
        in_ch: channels entering ``block1`` (input + condition concatenated).
        hid_ch: width of the residual/attention path and of both skip tensors.
        out_ch: channels produced by the final downsample / projection.
        time_dim: size of the time embedding consumed by ``block1``.
        is_last: use ``ChannelConversion`` instead of ``Downsample2D``.
        patch_size: ``None`` selects quadratic spatial attention; otherwise
            patch-based linear spatial attention with this patch size.
        num_groups, heads, dim_head: forwarded to the sub-modules.
    """

    def __init__(self, in_ch, hid_ch, out_ch, time_dim, is_last, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super(Down_Block, self).__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)
        if patch_size is None:
            attention = Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head)
        else:
            attention = Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)
        self.attn1 = Residual(PreNorm(hid_ch, attention))
        # NOTE(review): unlike block1 (and every block in MidBlock/Up_Block),
        # block2 gets no time_emb_dim, so the time embedding handed to it in
        # forward() is presumably ignored -- confirm whether this is intended.
        self.block2 = ResnetBlock(dim=hid_ch, dim_out=hid_ch, groups=num_groups)
        # The second attention is disabled in this variant; the no-op keeps the
        # two-skip interface of forward().
        self.attn2 = nn.Identity()
        if is_last:
            self.last = ChannelConversion(hid_ch, out_ch)
        else:
            self.last = Downsample2D(dim_in=hid_ch, dim_out=out_ch)

    def forward(self, x, time_emb, cond=None):
        """Encode one resolution level.

        Args:
            x: (B, T, C, H, W) tensor (enforced by the assert).
            time_emb: time embedding of shape (B, D); every frame of a sample
                receives that sample's embedding.
            cond: condition map concatenated with the folded input along
                dim 1; ``None`` substitutes zeros (unconditional branch).

        Returns:
            ``(out, skip1, skip2)``. ``out`` is reshaped back to (B, T, c, h, w).
            The two skips stay in the folded (B*T, ...) layout.
        """
        assert x.ndim==5
        batch, frames = x.shape[:2]
        folded = x.flatten(0, 1)  # (B*T, C, H, W)
        if cond is None:
            cond = torch.zeros_like(folded)  # unconditional: all-zero condition
        # (B, D) -> (B*T, D), each sample's embedding repeated once per frame
        t_emb = time_emb.repeat_interleave(frames, dim=0)
        hidden = self.block1(torch.cat((folded, cond), dim=1), t_emb)
        skip1 = self.attn1(hidden)
        skip2 = self.attn2(self.block2(skip1, t_emb))
        reduced = self.last(skip2)
        return reduced.reshape(batch, frames, *reduced.shape[1:]), skip1, skip2
class MidBlock(nn.Module):
    """Bottleneck of the spatial latent U-Net, applied frame by frame.

    ResnetBlock -> quadratic spatial attention -> ResnetBlock -> (no-op) ->
    ResnetBlock, all at ``in_ch`` channels and all conditioned on time.
    """

    def __init__(self, in_ch, time_dim, num_groups=8, heads=4, dim_head=32):
        super(MidBlock, self).__init__()
        self.block1 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)
        self.attn1 = Residual(PreNorm(in_ch, Quadratic_SpatialAttention(dim=in_ch, heads=heads, dim_head=dim_head)))
        self.block2 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)
        # Second attention disabled in this variant (kept as an explicit no-op).
        self.attn2 = nn.Identity()
        self.block3 = ResnetBlock(dim=in_ch, dim_out=in_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, relative_pos=None):
        """Process the bottleneck features.

        Args:
            x: (B, T, C, H, W) tensor (enforced by the assert).
            time_emb: (B, D) time embedding, shared by all frames of a sample.
            relative_pos: unused; accepted for signature compatibility.

        Returns:
            Tensor reshaped back to (B, T, c, h, w).
        """
        assert x.ndim==5
        batch, frames = x.shape[:2]
        t_emb = time_emb.repeat_interleave(frames, dim=0)  # (B*T, D)
        hidden = x.flatten(0, 1)
        hidden = self.attn1(self.block1(hidden, t_emb))
        hidden = self.attn2(self.block2(hidden, t_emb))
        hidden = self.block3(hidden, t_emb)
        return hidden.reshape(batch, frames, *hidden.shape[1:])
class Up_Block(nn.Module):
    """Decoder stage of the spatial latent U-Net, applied frame by frame.

    Upsamples (or channel-projects) the input and applies spatial attention.
    It then runs two ResnetBlocks, each preceded by a channel-wise
    concatenation with an encoder skip tensor.

    Args:
        in_chs: ``(in_ch, skip_ch)``, the input channels and the skip channels.
        hid_ch: width after the up-projection.
        out_ch: output channels of ``block2``.
        is_last: use ``ChannelConversion`` instead of ``Upsample2D``.
        time_dim: size of the time embedding.
        patch_size: ``None`` selects quadratic spatial attention; otherwise
            patch-based linear spatial attention with this patch size.
        num_groups, heads, dim_head: forwarded to the sub-modules.
    """

    def __init__(self, in_chs, hid_ch, out_ch, is_last, time_dim, patch_size=None, num_groups=8, heads=4, dim_head=32):
        super(Up_Block, self).__init__()
        in_ch, skip_ch = in_chs
        if is_last:
            self.up = ChannelConversion(in_ch, hid_ch)
        else:
            self.up = Upsample2D(dim_in=in_ch, dim_out=hid_ch)
        if patch_size is None:
            attention = Quadratic_SpatialAttention(dim=hid_ch, heads=heads, dim_head=dim_head)
        else:
            attention = Linear_SpatialAttention(dim=hid_ch, patch_size=patch_size, heads=heads, dim_head=dim_head)
        self.attn1 = Residual(PreNorm(hid_ch, attention))
        self.block1 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=hid_ch, time_emb_dim=time_dim, groups=num_groups)
        # Second attention disabled in this variant (kept as an explicit no-op).
        self.attn2 = nn.Identity()
        self.block2 = ResnetBlock(dim=hid_ch+skip_ch, dim_out=out_ch, time_emb_dim=time_dim, groups=num_groups)

    def forward(self, x, time_emb, skip1, skip2):
        """Decode one resolution level.

        Args:
            x: (B, T, C, H, W) tensor (enforced by the assert).
            time_emb: (B, D) time embedding, shared by all frames of a sample.
            skip1, skip2: encoder skips in the folded (B*T, ...) layout.

        Returns:
            Tensor reshaped back to (B, T, c, h, w).
        """
        assert x.ndim==5
        batch, frames = x.shape[:2]
        t_emb = time_emb.repeat_interleave(frames, dim=0)  # (B*T, D)
        hidden = self.attn1(self.up(x.flatten(0, 1)))
        # NOTE(review): skip2 is concatenated at BOTH stages and skip1 is never
        # used. A conventional U-Net would feed the other skip to block2.
        # Changing this would alter the outputs of checkpoints trained with the
        # current wiring, so confirm before touching it.
        hidden = self.block1(torch.cat((hidden, skip2), dim=1), t_emb)
        hidden = self.attn2(hidden)
        hidden = self.block2(torch.cat((hidden, skip2), dim=1), t_emb)
        return hidden.reshape(batch, frames, *hidden.shape[1:])
class LDM(nn.Module):
    """Spatial latent-diffusion U-Net that processes (B, T, C, H, W) latents frame-wise.

    Args:
        in_ch: channels of the noisy latent ``x`` (and of the level-0 condition).
        chs_mult: per-level channel multipliers applied to ``base_ch``.
        patch_size: ``None`` for quadratic spatial attention at every level.
            Otherwise it is the level-0 patch size for linear attention, halved
            at each deeper level (so it should be a power of two).
        num_groups, heads, dim_head: forwarded to the sub-blocks.
        base_ch: base width; the time embedding has ``4 * base_ch`` features.
    """

    def __init__(self, in_ch, chs_mult:tuple, patch_size=None, num_groups=8, heads=4, dim_head=32, base_ch=64):
        super(LDM, self).__init__()
        # Time Embedding MLP
        time_dim = 4*base_ch
        fourier_dim = base_ch
        self.time_mlp = Time_MLP(dim=base_ch, time_dim=time_dim, fourier_dim=fourier_dim)

        layer_no = len(chs_mult)
        chs = [in_ch, *(base_ch*m for m in chs_mult)]
        ch_in, ch_out = chs[:-1], chs[1:]
        up_in, up_out = ch_out[::-1], ch_in[::-1]
        patches = None if patch_size is None else [patch_size//(2**n) for n in range(layer_no)]  # Patch Size should be 2^N

        downs, ups, conditions = [], [], []
        # Sub-modules are created in the same interleaved order as before, so
        # parameter initialisation (RNG consumption) is unchanged.
        for n in range(layer_no):
            down_patch = None if patches is None else patches[n]
            up_patch = None if patches is None else patches[layer_no-n-1]
            downs.append(
                Down_Block(in_ch=2*ch_in[n], hid_ch=ch_in[n], out_ch=ch_out[n], time_dim=time_dim, patch_size=down_patch, is_last=(n==layer_no-1), num_groups=num_groups, heads=heads, dim_head=dim_head)
            )
            ups.append(
                Up_Block(in_chs=(up_in[n], ch_in[-n-1]), hid_ch=up_in[n], out_ch=up_out[n], time_dim=time_dim, patch_size=up_patch, is_last=(n==0), num_groups=num_groups, heads=heads, dim_head=dim_head)
            )
            # One condition downsampler per level. The former guard `n != -1`
            # was always true. The last level's output is never consumed (see
            # forward), but the module is kept so existing checkpoints still
            # load with strict=True.
            conditions.append(
                Downsample2D(dim_in=ch_in[n], dim_out=ch_out[n])
            )
        self.downs = nn.ModuleList(downs)
        self.ups = nn.ModuleList(ups)
        self.conditions = nn.ModuleList(conditions)
        self.mid = MidBlock(in_ch=ch_out[-1], time_dim=time_dim, num_groups=num_groups, heads=heads, dim_head=dim_head)

    def forward(self, x, time, conds=None):
        """Run the denoiser.

        Args:
            x: (B, T, C, H, W) noisy latent.
            time: diffusion timesteps, fed to the time MLP.
            conds: level-0 condition map, or ``None`` for the unconditional
                branch (each Down_Block then uses zeros).

        Returns:
            The decoder output of the last Up_Block.
        """
        t = self.time_mlp(time)
        hids1, hids2 = [], []
        last_level = len(self.downs) - 1
        for n, down_block in enumerate(self.downs):
            x, skip1, skip2 = down_block(x, t, conds)
            hids1.append(skip1)
            hids2.append(skip2)
            # Bring the condition to the next level's resolution and width.
            # After the last level nothing consumes it, so skip that wasted call.
            if conds is not None and n < last_level:
                conds = self.conditions[n](conds)
        out = self.mid(x, t)
        for up_block in self.ups:
            out = up_block(out, t, hids1.pop(), hids2.pop())
        return out
# Diffusion-process helpers, constants and the GaussianDiffusion wrapper
from collections import namedtuple
from torch.cuda.amp import autocast
import torch.nn.functional as F
from einops import reduce
from tqdm.auto import tqdm
# Pair returned by GaussianDiffusion.model_predictions: (predicted noise, predicted x_0).
ModelPrediction = namedtuple('ModelPrediction', 'pred_noise pred_x_start')
def identity(t, *args, **kwargs):
    """Return ``t`` untouched; any extra positional or keyword arguments are ignored."""
    result = t
    return result
def extract(a, t, x_shape):
    """Gather coefficients ``a[..., t]`` and reshape them to broadcast against ``x_shape``.

    Args:
        a: coefficient table indexed along its last dim (the schedule buffers
            in this file are 1-D).
        t: index tensor whose first dimension is the batch size B.
        x_shape: shape of the tensor the result will be multiplied with.

    Returns:
        Tensor of shape (B, 1, ..., 1) with ``len(x_shape)`` dimensions.
    """
    batch = t.shape[0]
    singleton_dims = [1] * (len(x_shape) - 1)
    gathered = a.gather(-1, t)
    return gathered.reshape(batch, *singleton_dims)
def default(val, d):
    """Return ``val`` when it is not None, else ``d``.

    A callable ``d`` is treated as a factory and called with no arguments.
    """
    if not exists(val):
        return d() if callable(d) else d
    return val


def exists(x):
    """Return True iff ``x`` is not None (falsy values such as 0 still exist)."""
    return x is not None
def guidance_scheduler(sampling_step: int, const: float):
    """Build a constant classifier-free-guidance schedule.

    Returns a float tensor of length ``sampling_step`` in which every entry is
    ``const``; GaussianDiffusion indexes it by diffusion timestep.
    """
    schedule = torch.ones(sampling_step)
    return schedule * const
class GaussianDiffusion(nn.Module):
    """Latent Gaussian diffusion trained jointly with a deterministic video backbone.

    ``vp_model`` (the backbone) is called as ``self.backbone(frames_in)`` and
    returns ``(backbone_output, conds, ...)``. It also owns a ``vae`` whose
    ``encode``/``decode`` map frames to and from the latent space the diffusion
    runs in. ``diffusion`` is the denoiser, called as
    ``diffusion(x_t, t, conds=...)``; ``conds=None`` is the unconditional branch
    used for classifier-free guidance (CFG).

    Args:
        vp_model: backbone module (see above).
        diffusion: denoising network.
        timesteps: number of forward-diffusion steps.
        sampling_timesteps: number of reverse steps; fewer than ``timesteps``
            switches sampling to DDIM.
        objective: 'pred_noise', 'pred_x0' or 'pred_v'.
        beta_schedule: 'linear', 'cosine' or 'sigmoid'.
        schedule_fn_kwargs: extra keyword arguments for the beta schedule
            (``None`` means none).
        ddim_sampling_eta: DDIM stochasticity (0 = deterministic).
        offset_noise_strength: offset-noise strength used during training.
        min_snr_loss_weight, min_snr_gamma: Min-SNR loss weighting.
    """

    def __init__(
        self,
        vp_model,
        diffusion,
        timesteps = 1000,
        sampling_timesteps = None,
        objective = 'pred_v',
        beta_schedule = 'sigmoid',
        schedule_fn_kwargs = None,
        ddim_sampling_eta = 0.,
        offset_noise_strength = 0.,  # https://www.crosslabs.org/blog/diffusion-with-offset-noise
        min_snr_loss_weight = False,  # https://arxiv.org/abs/2303.09556
        min_snr_gamma = 5
    ):
        super(GaussianDiffusion, self).__init__()
        self.backbone = vp_model
        self.diff_unet = diffusion
        self.objective = objective
        assert objective in {'pred_noise', 'pred_x0', 'pred_v'}, 'objective must be either pred_noise (predict noise) or pred_x0 (predict image start) or pred_v (predict v [v-parameterization as defined in appendix D of progressive distillation paper, used in imagen-video successfully])'
        # CFG schedule; None (guidance off) until setup_guidance() is called, so
        # model_predictions() works even if that call is skipped.
        self.CFG_sch = None

        if beta_schedule == 'linear':
            beta_schedule_fn = linear_beta_schedule
        elif beta_schedule == 'cosine':
            beta_schedule_fn = cosine_beta_schedule
        elif beta_schedule == 'sigmoid':
            beta_schedule_fn = sigmoid_beta_schedule
        else:
            raise ValueError(f'unknown beta schedule {beta_schedule}')
        # None replaces the former shared mutable `dict()` default
        schedule_fn_kwargs = {} if schedule_fn_kwargs is None else schedule_fn_kwargs
        betas = beta_schedule_fn(timesteps, **schedule_fn_kwargs)
        alphas = 1. - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        alphas_cumprod_prev = F.pad(alphas_cumprod[:-1], (1, 0), value = 1.)
        timesteps, = betas.shape
        self.num_timesteps = int(timesteps)

        # sampling related parameters
        self.sampling_timesteps = default(sampling_timesteps, timesteps)  # default: as many steps as in training
        assert self.sampling_timesteps <= timesteps
        self.is_ddim_sampling = self.sampling_timesteps < timesteps
        self.ddim_sampling_eta = ddim_sampling_eta

        # helper function to register buffer from float64 to float32
        register_buffer = lambda name, val: self.register_buffer(name, val.to(torch.float32))
        register_buffer('betas', betas)
        register_buffer('alphas_cumprod', alphas_cumprod)
        register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)

        # calculations for diffusion q(x_t | x_{t-1}) and others
        register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
        register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
        register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
        register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
        register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))

        # calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
        # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
        register_buffer('posterior_variance', posterior_variance)
        # log is clipped because the posterior variance is 0 at the start of the chain
        register_buffer('posterior_log_variance_clipped', torch.log(posterior_variance.clamp(min =1e-20)))
        register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
        register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))

        # offset noise strength - in blogpost, they claimed 0.1 was ideal
        self.offset_noise_strength = offset_noise_strength

        # loss weight derived from the signal-to-noise ratio (https://arxiv.org/abs/2303.09556)
        snr = alphas_cumprod / (1 - alphas_cumprod)
        maybe_clipped_snr = snr.clone()
        if min_snr_loss_weight:
            maybe_clipped_snr.clamp_(max = min_snr_gamma)
        if objective == 'pred_noise':
            register_buffer('loss_weight', maybe_clipped_snr / snr)
        elif objective == 'pred_x0':
            register_buffer('loss_weight', maybe_clipped_snr)
        elif objective == 'pred_v':
            register_buffer('loss_weight', maybe_clipped_snr / (snr + 1))

    @property
    def device(self):
        """Device of the registered diffusion buffers."""
        return self.betas.device

    def setup_guidance(self, scheduler):
        """Install a CFG schedule (1-D tensor indexed by timestep), or ``None`` to disable CFG."""
        self.CFG_sch = scheduler.to(self.device) if exists(scheduler) else None

    def predict_start_from_noise(self, x_t, t, noise):
        """x_0 = sqrt(1/a_bar_t) * x_t - sqrt(1/a_bar_t - 1) * eps."""
        return (
            extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
        )

    def predict_noise_from_start(self, x_t, t, x0):
        """Invert predict_start_from_noise for eps."""
        return (
            (extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - x0) /
            extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
        )

    def predict_v(self, x_start, t, noise):
        """v = sqrt(a_bar_t) * eps - sqrt(1 - a_bar_t) * x_0."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
        )

    def predict_start_from_v(self, x_t, t, v):
        """x_0 = sqrt(a_bar_t) * x_t - sqrt(1 - a_bar_t) * v."""
        return (
            extract(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
        )

    def q_posterior(self, x_start, x_t, t):
        """Mean, variance and clipped log-variance of q(x_{t-1} | x_t, x_0)."""
        posterior_mean = (
            extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
            extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
        )
        posterior_variance = extract(self.posterior_variance, t, x_t.shape)
        posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def model_predictions(self, x, t, cond, clip_x_start = False, rederive_pred_noise = False):
        """Run the denoiser and convert its output into (pred_noise, pred_x_start).

        With a CFG schedule installed, the conditional output is combined as
        ``cond_out + w * (cond_out - uncond_out)``, where ``w = CFG_sch[t[0]]``
        (the whole batch is assumed to share one timestep).
        """
        if exists(self.CFG_sch):
            uncond = self.diff_unet(x, t, conds=None)
            model_output = self.diff_unet(x, t, conds=cond)
            step = int(t[0])
            model_output = model_output - self.CFG_sch[step] * (uncond - model_output)
        else:
            model_output = self.diff_unet(x, t, conds=cond)

        # plain clamp: `partial` is not among this file's visible imports
        maybe_clip = (lambda tensor: torch.clamp(tensor, min = -1., max = 1.)) if clip_x_start else identity

        if self.objective == 'pred_noise':
            pred_noise = model_output
            x_start = maybe_clip(self.predict_start_from_noise(x, t, pred_noise))
            if clip_x_start and rederive_pred_noise:
                pred_noise = self.predict_noise_from_start(x, t, x_start)
        elif self.objective == 'pred_x0':
            x_start = maybe_clip(model_output)
            pred_noise = self.predict_noise_from_start(x, t, x_start)
        elif self.objective == 'pred_v':
            x_start = maybe_clip(self.predict_start_from_v(x, t, model_output))
            pred_noise = self.predict_noise_from_start(x, t, x_start)
        return ModelPrediction(pred_noise, x_start)

    def p_mean_variance(self, x, t, cond=None, clip_denoised = True):
        """Posterior mean/variance of x_{t-1} given the model's x_0 estimate."""
        preds = self.model_predictions(x, t, cond=cond, clip_x_start=False)
        x_start = preds.pred_x_start
        if clip_denoised:
            x_start.clamp_(-1., 1.)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start = x_start, x_t = x, t = t)
        return model_mean, posterior_variance, posterior_log_variance, x_start

    @torch.no_grad()
    def p_sample(self, x, t: int, cond=None):
        """One ancestral (DDPM) reverse step at integer timestep ``t``."""
        batched_times = torch.full((x.shape[0],), t, device = self.device, dtype = torch.long)
        model_mean, _, model_log_variance, x_start = self.p_mean_variance(x = x, t = batched_times, cond=cond, clip_denoised = False)
        noise = torch.randn_like(x) if t > 0 else 0.  # no noise if t == 0
        pred_img = model_mean + (0.5 * model_log_variance).exp() * noise
        return pred_img, x_start

    @torch.no_grad()
    def p_sample_loop(self, shape, cond=None, return_all_timesteps = False):
        """Full DDPM reverse chain from Gaussian noise of ``shape``.

        Returns the final sample, or all intermediate samples stacked on dim 1.
        """
        frames_pred = torch.randn(shape, device = self.device)
        imgs = [frames_pred]
        for t in tqdm(reversed(range(0, self.num_timesteps)), desc = 'sampling loop time step', total = self.num_timesteps, disable=True):
            frames_pred, _ = self.p_sample(frames_pred, t, cond=cond)
            imgs.append(frames_pred)
        return frames_pred if not return_all_timesteps else torch.stack(imgs, dim = 1)

    @torch.no_grad()
    def ddim_sample(self, shape, cond=None, return_all_timesteps = False):
        """DDIM reverse chain with ``self.sampling_timesteps`` steps.

        Returns the final sample, or all intermediate samples stacked on dim 1.
        """
        batch, device = shape[0], self.device
        total_timesteps, sampling_timesteps, eta = self.num_timesteps, self.sampling_timesteps, self.ddim_sampling_eta

        times = torch.linspace(-1, total_timesteps - 1, steps = sampling_timesteps + 1)  # [-1, 0, 1, 2, ..., T-1] when sampling_timesteps == total_timesteps
        times = list(reversed(times.int().tolist()))
        time_pairs = list(zip(times[:-1], times[1:]))  # [(T-1, T-2), (T-2, T-3), ..., (1, 0), (0, -1)]

        frames_pred = torch.randn(shape, device = device)
        imgs = [frames_pred]
        for time, time_next in tqdm(time_pairs, desc = 'sampling loop time step', disable=True):
            time_cond = torch.full((batch,), time, device = device, dtype = torch.long)
            pred_noise, x_start, *_ = self.model_predictions(
                frames_pred,
                time_cond,
                cond = cond,
                clip_x_start = False,
                rederive_pred_noise = True
            )
            if time_next < 0:
                frames_pred = x_start
                imgs.append(frames_pred)
                continue
            alpha = self.alphas_cumprod[time]
            alpha_next = self.alphas_cumprod[time_next]
            sigma = eta * ((1 - alpha / alpha_next) * (1 - alpha_next) / (1 - alpha)).sqrt()
            c = (1 - alpha_next - sigma ** 2).sqrt()
            noise = torch.randn_like(frames_pred)
            frames_pred = x_start * alpha_next.sqrt() + c * pred_noise + sigma * noise
            imgs.append(frames_pred)
        return frames_pred if not return_all_timesteps else torch.stack(imgs, dim = 1)

    def _decode_latents(self, latents, batch):
        """Decode (B, T, c, h, w) latents frame-wise with the backbone VAE into (B, T, ...)."""
        flat = rearrange(latents, 'b t c h w -> (b t) c h w')
        frames = self.backbone.vae.decode(flat)
        return rearrange(frames, '(b t) c h w -> b t c h w', b=batch)

    @torch.no_grad()
    def sample(self, frames_in, return_all_timesteps = False):
        """Sample future frames conditioned on ``frames_in`` (B, T_in, C, H, W).

        Returns:
            ``(frames_pred, backbone_output)``. With ``return_all_timesteps`` the
            frames carry an extra step axis at dim 1: (B, S, T, C, H, W).
            Previously that case crashed in the 5-D rearrange.
        """
        assert frames_in.ndim == 5
        B = frames_in.shape[0]
        backbone_output, conds, *_ = self.backbone(frames_in)  # updated for Updated loss function on 03/07
        sample_fn = self.ddim_sample if self.is_ddim_sampling else self.p_sample_loop
        *_, c, h, w = conds.shape
        tgt_shape = conds.reshape(B, -1, c, h, w).shape
        ldm_pred = sample_fn(
            tgt_shape,
            cond=conds,
            return_all_timesteps = return_all_timesteps
        )
        if return_all_timesteps:
            # (B, S, T, c, h, w): decode one step at a time to bound memory
            frames_pred = torch.stack(
                [self._decode_latents(ldm_pred[:, s], B) for s in range(ldm_pred.shape[1])],
                dim = 1
            )
        else:
            frames_pred = self._decode_latents(ldm_pred, B)
        return frames_pred, backbone_output

    def predict(self, frames_in, compute_loss=False, **kwargs):
        """Return ``(prediction, backbone_output)``; extra arguments are ignored."""
        pred, mu = self.sample(frames_in=frames_in)
        return pred, mu

    def compute_loss(self, frames_in, frames_gt, validate=False):
        """Compute the training losses.

        Args:
            frames_in: (B, T_in, C, H, W) observed frames.
            frames_gt: (B, T_out, C, H, W) target frames.
            validate: accepted for API compatibility; it does not change the
                computation.

        Returns:
            ``(recon_loss, mu_loss, diff_loss, prior_loss)``.
        """
        B = frames_in.shape[0]

        # --- Diffusion loss ---
        # `*_` tolerates extra backbone outputs, matching sample()
        _, conds, *_ = self.backbone(frames_in)
        hid_gt, _ = self.backbone.vae.encode(
            rearrange(frames_gt, 'b t c h w -> (b t) c h w')
        )
        hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
        t = torch.randint(0, self.num_timesteps, (B,), device=self.device).long()
        if random.random() > 0.85:  # ~15% of steps train the unconditional (CFG) branch
            conds = None
        diff_loss = self.p_losses(hid_gt.detach(), t, cond=conds)

        # --- Backbone loss ---
        mu_loss = self.backbone._losses_(frames_in, frames_gt)

        # --- VAE loss (reconstruct input and target frames) ---
        all_frames = rearrange(torch.cat((frames_in, frames_gt), dim=1), 'b t c h w -> (b t) c h w')
        ae_loss, kl_loss = self.backbone.vae._losses_(all_frames, all_frames)
        kl_weight = 1E-6
        recon_loss = ae_loss + kl_weight*kl_loss

        # --- Prior loss at t = T (noisy latent vs. standard normal) ---
        # NOTE(review): frames_gt is encoded a second time here, as in the
        # original. If vae.encode is deterministic this equals the first encode
        # and could be reused to save a VAE pass; confirm before changing.
        hid_gt, _ = self.backbone.vae.encode(
            rearrange(frames_gt, 'b t c h w -> (b t) c h w')
        )
        hid_gt = rearrange(hid_gt, '(b t) c h w -> b t c h w', b=B)
        T = torch.full((B,), self.num_timesteps - 1, device=self.device, dtype=torch.long)
        mu_noisy = extract(self.sqrt_alphas_cumprod, T, hid_gt.shape) * hid_gt
        sigma_noisy = extract(self.sqrt_one_minus_alphas_cumprod, T, hid_gt.shape)
        log_var_noisy = 2*torch.log(sigma_noisy)
        prior_loss = self.kl_from_standard_normal(mu_noisy, log_var_noisy)
        return recon_loss, mu_loss, diff_loss, prior_loss

    def kl_from_standard_normal(self, mean, log_var):
        """Mean elementwise KL( N(mean, exp(log_var)) || N(0, 1) )."""
        kl = 0.5 * (log_var.exp() + mean.square() - 1.0 - log_var)
        return kl.mean()

    @autocast(enabled = False)
    def q_sample(self, x_start, t, noise = None):
        """Sample x_t ~ q(x_t | x_0) = sqrt(a_bar_t) x_0 + sqrt(1 - a_bar_t) eps."""
        noise = default(noise, lambda: torch.randn_like(x_start))
        return (
            extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
            extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
        )

    def p_losses(self, x_start, t, noise=None, offset_noise_strength=None, cond=None):
        """Weighted MSE between the denoiser output and the objective's target.

        ``x_start`` is expected to be (B, T, C, H, W); the loss is averaged per
        sample, weighted by ``loss_weight[t]``, then averaged over the batch.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))

        # offset noise - https://www.crosslabs.org/blog/diffusion-with-offset-noise
        offset_noise_strength = default(offset_noise_strength, self.offset_noise_strength)
        if offset_noise_strength > 0.:
            # One offset per leading index (B, T, C), broadcast over H and W.
            # The former `(b, c) -> b c 1 1` shape did not broadcast correctly
            # against 5-D latents. Built out of place so a caller-supplied
            # `noise` tensor is not mutated.
            offset_noise = torch.randn(x_start.shape[:-2], device = self.device)
            noise = noise + offset_noise_strength * offset_noise[..., None, None]

        # noise sample
        x = self.q_sample(x_start=x_start, t=t, noise=noise)
        model_out = self.diff_unet(x, t, conds=cond)

        if self.objective == 'pred_noise':
            target = noise
        elif self.objective == 'pred_x0':
            target = x_start
        elif self.objective == 'pred_v':
            target = self.predict_v(x_start, t, noise)
        else:
            raise ValueError(f'unknown objective {self.objective}')

        loss = F.mse_loss(model_out, target, reduction = 'none')
        loss = reduce(loss, 'b ... -> b', 'mean')
        loss = loss * extract(self.loss_weight, t, loss.shape)
        return loss.mean()

    @torch.no_grad()
    def forward(self, input_x, include_mu=False, **kwargs):
        """Inference entry point: return the prediction, plus the backbone output if ``include_mu``."""
        pred, mu = self.predict(input_x, compute_loss=False)
        if include_mu:
            return pred, mu
        return pred
from stldm.modules import SimVPV2_Model, VAE
def model_setup(model_config, print_info=False, cfg_str=None):
    """Build the full spatial STLDM: SimVPV2 backbone + LDM denoiser in GaussianDiffusion.

    Args:
        model_config: dict with 'vp_param' (SimVPV2_Model kwargs), 'stldm_param'
            (LDM kwargs) and 'param' (GaussianDiffusion kwargs).
        print_info: print a short description of this model variant.
        cfg_str: constant classifier-free-guidance weight; ``None`` disables CFG.

    Returns:
        The configured GaussianDiffusion model.
    """
    if print_info:
        print('Setup a Spatial diffusion with replacing a Temporal attention with Spatial attention')
        print('This is a diffusion with the consideration of (BT, C, H, W)')
        print('Train it from end to end')
    vp_config = model_config['vp_param']
    ldm_config = model_config['stldm_param']
    vpm = SimVPV2_Model(**vp_config)
    ldm = LDM(**ldm_config)
    model = GaussianDiffusion(vp_model=vpm, diffusion=ldm, **model_config['param'])
    # Size the schedule from the model itself. 'timesteps' may be omitted from
    # model_config['param'] (GaussianDiffusion then uses its default), which
    # previously raised KeyError here. The schedule is indexed by timestep.
    scheduler = None if cfg_str is None else guidance_scheduler(sampling_step=model.num_timesteps, const=cfg_str)
    model.setup_guidance(scheduler)
    return model
def ae_setup(model_config):
    """Build a SimVPV2 backbone from ``model_config['vp_param']`` and return only its VAE."""
    backbone = SimVPV2_Model(**model_config['vp_param'])
    return backbone.vae
def backbone_setup(model_config):
    """Build and return a SimVPV2 backbone from ``model_config['vp_param']``."""
    return SimVPV2_Model(**model_config['vp_param'])