import numpy as np
import torch
import torch.nn as nn

from .tools.wan_vae import WanVAE_


class VAEWanModel(nn.Module):
    def __init__(
        self,
        input_dim,
        mean_path=None,
        std_path=None,
        z_dim=256,
        dim=160,
        dec_dim=512,
        num_res_blocks=1,
        dropout=0.0,
        dim_mult=[1, 1, 1],
        temperal_downsample=[True, True],
        spatial_downsample=[False, False],
        spatial_dim=0,
        input_keys={
            "feature": "feature",
            "feature_length": "feature_length",
        },
        **kwargs,
    ):
        super().__init__()
        self.input_keys = input_keys
        self.mean_path = mean_path
        self.std_path = std_path
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.dim = dim
        self.dec_dim = dec_dim
        self.num_res_blocks = num_res_blocks
        self.dropout = dropout
        self.dim_mult = dim_mult
        self.temperal_downsample = temperal_downsample
        self.spatial_downsample = spatial_downsample
        self.spatial_dim = spatial_dim

        self.RECONS_LOSS = nn.SmoothL1Loss(reduction="none")
        self.LAMBDA_FEATURE = kwargs.get("LAMBDA_FEATURE", 1.0)
        self.LAMBDA_KL = kwargs.get("LAMBDA_KL", 10e-6)

        # Per-dimension reconstruction weights (default: all ones).
        # If shorter than input_dim, pad with 1s at the end.
        recons_weights = kwargs.get("recons_weights", None)
        if recons_weights is not None:
            w = torch.tensor(recons_weights, dtype=torch.float32)
            if w.numel() < input_dim:
                w = torch.cat([w, torch.ones(input_dim - w.numel())])
            self.register_buffer("recons_weights", w[:input_dim], persistent=False)
        else:
            self.register_buffer(
                "recons_weights",
                torch.ones(input_dim, dtype=torch.float32),
                persistent=False,
            )

        # Optional dataset statistics for input normalization; identity
        # normalization (zero mean, unit std) when no paths are given.
        if self.mean_path is not None:
            self.register_buffer(
                "mean", torch.from_numpy(np.load(self.mean_path)).float()
            )
        else:
            self.register_buffer("mean", torch.zeros(input_dim))
        if self.std_path is not None:
            self.register_buffer(
                "std", torch.from_numpy(np.load(self.std_path)).float()
            )
        else:
            self.register_buffer("std", torch.ones(input_dim))

        self.model = WanVAE_(
            input_dim=self.input_dim,
            dim=self.dim,
            dec_dim=self.dec_dim,
            z_dim=self.z_dim,
            dim_mult=self.dim_mult,
            num_res_blocks=self.num_res_blocks,
            temperal_downsample=self.temperal_downsample,
            spatial_downsample=self.spatial_downsample,
            spatial_dim=self.spatial_dim,
            dropout=self.dropout,
        )

        # Each enabled temporal downsample stage halves the sequence length.
        downsample_factor = 1
        for flag in self.temperal_downsample:
            if flag:
                downsample_factor *= 2
        self.downsample_factor = downsample_factor

    def _extract_inputs(self, x):
        inputs = {}
        for internal_key, external_key in self.input_keys.items():
            if external_key in x:
                inputs[internal_key] = x[external_key]
        return inputs

    def preprocess(self, x):
        """Convert last-channel batched format to channel-first, padding to 5D
        (B, C, T, H, W).

        (B, T, C)       -> (B, C, T, 1, 1)
        (B, T, H, C)    -> (B, C, T, H, 1)
        (B, T, H, W, C) -> (B, C, T, H, W)
        """
        ndim = x.ndim
        if ndim == 3:  # (B, T, C)
            x = x.permute(0, 2, 1)[:, :, :, None, None]
        elif ndim == 4:  # (B, T, H, C)
            x = x.permute(0, 3, 1, 2)[:, :, :, :, None]
        elif ndim == 5:  # (B, T, H, W, C)
            x = x.permute(0, 4, 1, 2, 3)
        return x
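
    # Illustrative shape walk-through (values assumed for the example only):
    # a feature batch x of shape (2, 100, 63), i.e. B=2, T=100, C=63, becomes
    # preprocess(x) of shape (2, 63, 100, 1, 1); postprocess below inverts
    # this exactly.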

    def postprocess(self, x):
        """Reverse of preprocess: channel-first 5D back to last-channel,
        stripping padding dims.

        (B, C, T, 1, 1) -> (B, T, C)
        (B, C, T, H, 1) -> (B, T, H, C)
        (B, C, T, H, W) -> (B, T, H, W, C)
        """
        shape = x.shape  # (B, C, T, H, W)
        if shape[3] == 1 and shape[4] == 1:
            # (B, C, T, 1, 1) -> (B, T, C)
            x = x[:, :, :, 0, 0].permute(0, 2, 1)
        elif shape[4] == 1:
            # (B, C, T, H, 1) -> (B, T, H, C)
            x = x[:, :, :, :, 0].permute(0, 2, 3, 1)
        else:
            # (B, C, T, H, W) -> (B, T, H, W, C)
            x = x.permute(0, 2, 3, 4, 1)
        return x

    def forward(self, x):
        x = self._extract_inputs(x)
        features = x["feature"]
        feature_length = x["feature_length"]
        features = (features - self.mean) / self.std

        # Create a padding mask from feature_length: True marks valid timesteps.
        batch_size, seq_len = features.shape[:2]
        mask = torch.zeros(
            batch_size, seq_len, dtype=torch.bool, device=features.device
        )
        for i in range(batch_size):
            mask[i, : feature_length[i]] = True

        x_in = self.preprocess(features)  # (bs, input_dim, T, 1, 1)
        mu, log_var = self.model.encode(
            x_in, scale=[0, 1], return_dist=True
        )  # (bs, z_dim, T, 1, 1)
        z = self.model.reparameterize(mu, log_var)
        x_decoder = self.model.decode(z, scale=[0, 1])  # (bs, input_dim, T, 1, 1)
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)

        # The decoder may return a slightly different temporal length; align
        # both tensors (and the mask) on the shorter one before the loss.
        if x_out.size(1) != features.size(1):
            min_len = min(x_out.size(1), features.size(1))
            x_out = x_out[:, :min_len]
            features = features[:, :min_len]
            mask = mask[:, :min_len]

        # Broadcast the (B, T) mask over all trailing feature dimensions.
        mask_expanded = mask
        for _ in range(features.ndim - 2):
            mask_expanded = mask_expanded.unsqueeze(-1)

        loss_per_element = self.RECONS_LOSS(x_out, features)
        loss_recons = (
            (loss_per_element * mask_expanded * self.recons_weights).sum()
            / mask_expanded.sum()
            / self.recons_weights.sum()
        )

        # Compute KL divergence loss:
        # KL(N(mu, sigma) || N(0, 1)) = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        # log_var = log(sigma^2), so we can use it directly.

        # Build the mask in latent space: temporal downsampling shortens the
        # sequence, so each valid latent length is ceil(length / downsample_factor).
        T_latent = mu.size(2)
        mask_downsampled = torch.zeros(
            batch_size, T_latent, dtype=torch.bool, device=features.device
        )
        for i in range(batch_size):
            latent_length = (
                feature_length[i] + self.downsample_factor - 1
            ) // self.downsample_factor
            mask_downsampled[i, :latent_length] = True
        mask_latent = (
            mask_downsampled.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
        )  # (B, 1, T_latent, 1, 1)

        # Compute KL loss per element.
        kl_per_element = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
        # Apply mask: only compute KL loss for valid timesteps.
        kl_masked = kl_per_element * mask_latent
        # Sum over all dimensions and normalize by the number of valid elements.
        num_latent_elements = mu.size(1) * mu.size(3) * mu.size(4)  # C * H * W
        kl_loss = torch.sum(kl_masked) / (
            torch.sum(mask_downsampled) * num_latent_elements
        )  # normalize by valid timesteps * (C * H * W)

        # Total loss.
        total_loss = self.LAMBDA_FEATURE * loss_recons + self.LAMBDA_KL * kl_loss

        loss_dict = {}
        loss_dict["total"] = total_loss
        loss_dict["recons"] = loss_recons
        loss_dict["kl"] = kl_loss
        return loss_dict

    def encode(self, x):
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)  # (bs, T, input_dim) -> (bs, input_dim, T, 1, 1)
        mu = self.model.encode(x_in, scale=[0, 1])  # (bs, z_dim, T, 1, 1)
        mu = self.postprocess(mu)  # (bs, T, z_dim)
        return mu

    def decode(self, mu):
        mu_in = self.preprocess(mu)  # (bs, T, z_dim) -> (bs, z_dim, T, 1, 1)
        x_decoder = self.model.decode(mu_in, scale=[0, 1])  # (bs, z_dim, T, 1, 1)
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)
        x_out = x_out * self.std + self.mean
        return x_out
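
    # Streaming usage sketch. The call order below is an assumption inferred
    # from the first_chunk flag and clear_cache(), not a documented contract:
    #
    #   vae.clear_cache()
    #   z0 = vae.stream_encode(chunk0, first_chunk=True)
    #   z1 = vae.stream_encode(chunk1, first_chunk=False)
    #   ...
    #   vae.clear_cache()  # reset cached state before the next sequence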

    @torch.no_grad()
    def stream_encode(self, x, first_chunk=True):
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)  # (bs, input_dim, T, 1, 1)
        mu = self.model.stream_encode(x_in, first_chunk=first_chunk, scale=[0, 1])
        mu = self.postprocess(mu)  # (bs, T, z_dim)
        return mu

    @torch.no_grad()
    def stream_decode(self, mu, first_chunk=True):
        mu_in = self.preprocess(mu)  # (bs, z_dim, T, 1, 1)
        x_decoder = self.model.stream_decode(
            mu_in, first_chunk=first_chunk, scale=[0, 1]
        )
        x_out = self.postprocess(x_decoder)  # (bs, T, input_dim)
        x_out = x_out * self.std + self.mean
        return x_out

    def clear_cache(self):
        self.model.clear_cache()

    def generate(self, x):
        x = self._extract_inputs(x)
        features = x["feature"]
        feature_length = x["feature_length"]
        y_hat = self.decode(self.encode(features))

        y_hat_out = []
        for i in range(y_hat.shape[0]):
            # Cut off the padding, snapping each length to the temporal
            # downsampling grid of the VAE.
            valid_len = (
                feature_length[i] - 1
            ) // self.downsample_factor * self.downsample_factor + 1
            y_hat_out.append(y_hat[i, :valid_len])

        out = {}
        out["generated"] = y_hat_out
        return out
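

# Minimal smoke-test sketch, not part of the model code. The sizes below
# (input_dim=63, z_dim=16, dim=32, dec_dim=64) are made-up example values and
# must satisfy whatever constraints WanVAE_ imposes; also, the relative import
# at the top means this module has to run in its package context
# (e.g. `python -m package.module`), not as a standalone script.
if __name__ == "__main__":
    model = VAEWanModel(input_dim=63, z_dim=16, dim=32, dec_dim=64)
    batch = {
        "feature": torch.randn(2, 100, 63),         # (B, T, C)
        "feature_length": torch.tensor([100, 80]),  # valid lengths per sample
    }

    # Training-style forward pass: returns the loss dictionary.
    losses = model(batch)
    print({k: float(v) for k, v in losses.items()})

    # Reconstruction with padding stripped per sample.
    out = model.generate(batch)
    print([t.shape for t in out["generated"]])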