| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from .tools.wan_vae import WanVAE_ |
|
|
|
|
class VAEWanModel(nn.Module):
    """Wrapper around ``WanVAE_`` for batched, channel-last feature sequences.

    The wrapper normalizes inputs with stored ``mean``/``std`` buffers,
    converts between channel-last tensors and the 5D channel-first layout
    handed to the wrapped model, and computes a length-masked
    reconstruction + KL training loss in :meth:`forward`.
    """

    def __init__(
        self,
        input_dim,
        mean_path=None,
        std_path=None,
        z_dim=256,
        dim=160,
        dec_dim=512,
        num_res_blocks=1,
        dropout=0.0,
        dim_mult=None,
        temperal_downsample=None,
        spatial_downsample=None,
        spatial_dim=0,
        input_keys=None,
        **kwargs,
    ):
        """Build the wrapped VAE and register normalization/weight buffers.

        Args:
            input_dim: Size of the feature (last) dimension.
            mean_path: Optional ``.npy`` file loaded into the ``mean`` buffer;
                zeros of size ``input_dim`` are used when omitted.
            std_path: Optional ``.npy`` file loaded into the ``std`` buffer;
                ones of size ``input_dim`` are used when omitted.
            z_dim, dim, dec_dim, num_res_blocks, dropout, spatial_dim:
                Forwarded unchanged to ``WanVAE_``.
            dim_mult: Forwarded to ``WanVAE_``; defaults to ``[1, 1, 1]``.
            temperal_downsample: Forwarded to ``WanVAE_``; defaults to
                ``[True, True]``. Each true entry doubles
                ``self.downsample_factor``.
            spatial_downsample: Forwarded to ``WanVAE_``; defaults to
                ``[False, False]``.
            input_keys: Mapping ``internal_key -> key in the batch dict``;
                defaults to identity for ``"feature"`` and
                ``"feature_length"``.
            **kwargs: Optional ``LAMBDA_FEATURE`` (default 1.0), ``LAMBDA_KL``
                (default ``10e-6``) and ``recons_weights`` (per-channel
                reconstruction weights). Other keys are ignored.
        """
        super().__init__()

        # Copy list/dict arguments so instances never share (or mutate) the
        # caller's objects or a shared default.
        self.input_keys = (
            {"feature": "feature", "feature_length": "feature_length"}
            if input_keys is None
            else dict(input_keys)
        )
        self.dim_mult = [1, 1, 1] if dim_mult is None else list(dim_mult)
        self.temperal_downsample = (
            [True, True]
            if temperal_downsample is None
            else list(temperal_downsample)
        )
        self.spatial_downsample = (
            [False, False]
            if spatial_downsample is None
            else list(spatial_downsample)
        )

        self.mean_path = mean_path
        self.std_path = std_path
        self.input_dim = input_dim
        self.z_dim = z_dim
        self.dim = dim
        self.dec_dim = dec_dim
        self.num_res_blocks = num_res_blocks
        self.dropout = dropout
        self.spatial_dim = spatial_dim

        # Element-wise loss; masking and reduction happen in forward().
        self.RECONS_LOSS = nn.SmoothL1Loss(reduction="none")
        self.LAMBDA_FEATURE = kwargs.get("LAMBDA_FEATURE", 1.0)
        # NOTE: 10e-6 == 1e-5.
        self.LAMBDA_KL = kwargs.get("LAMBDA_KL", 10e-6)

        # Per-channel reconstruction weights: padded with ones or truncated
        # so the buffer always has exactly ``input_dim`` entries.
        recons_weights = kwargs.get("recons_weights", None)
        if recons_weights is not None:
            w = torch.tensor(recons_weights, dtype=torch.float32)
            if w.numel() < input_dim:
                w = torch.cat([w, torch.ones(input_dim - w.numel())])
            w = w[:input_dim]
        else:
            w = torch.ones(input_dim, dtype=torch.float32)
        self.register_buffer("recons_weights", w, persistent=False)

        if self.mean_path is not None:
            mean = torch.from_numpy(np.load(self.mean_path)).float()
        else:
            mean = torch.zeros(input_dim)
        self.register_buffer("mean", mean)

        if self.std_path is not None:
            std = torch.from_numpy(np.load(self.std_path)).float()
        else:
            std = torch.ones(input_dim)
        self.register_buffer("std", std)

        self.model = WanVAE_(
            input_dim=self.input_dim,
            dim=self.dim,
            dec_dim=self.dec_dim,
            z_dim=self.z_dim,
            dim_mult=self.dim_mult,
            num_res_blocks=self.num_res_blocks,
            temperal_downsample=self.temperal_downsample,
            spatial_downsample=self.spatial_downsample,
            spatial_dim=self.spatial_dim,
            dropout=self.dropout,
        )

        # Total temporal reduction: x2 for every enabled downsample stage.
        self.downsample_factor = 2 ** sum(
            1 for flag in self.temperal_downsample if flag
        )

    def _extract_inputs(self, x):
        """Pick the configured keys out of batch dict ``x``.

        Returns a dict keyed by the internal names; keys missing from ``x``
        are silently omitted.
        """
        return {
            internal_key: x[external_key]
            for internal_key, external_key in self.input_keys.items()
            if external_key in x
        }

    @staticmethod
    def _length_mask(lengths, max_len):
        """Return a ``(len(lengths), max_len)`` bool mask, true where ``t < length``."""
        positions = torch.arange(max_len, device=lengths.device)
        return positions.unsqueeze(0) < lengths.unsqueeze(1)

    def preprocess(self, x):
        """Convert last-channel batched format to channel-first, padding to 5D (B, C, T, H, W).

        (B, T, C)       -> (B, C, T, 1, 1)
        (B, T, H, C)    -> (B, C, T, H, 1)
        (B, T, H, W, C) -> (B, C, T, H, W)

        Tensors of any other rank are returned unchanged.
        """
        ndim = x.ndim
        if ndim == 3:
            x = x.permute(0, 2, 1)[:, :, :, None, None]
        elif ndim == 4:
            x = x.permute(0, 3, 1, 2)[:, :, :, :, None]
        elif ndim == 5:
            x = x.permute(0, 4, 1, 2, 3)
        return x

    def postprocess(self, x):
        """Reverse of preprocess: channel-first 5D back to last-channel, stripping padding dims.

        (B, C, T, 1, 1) -> (B, T, C)
        (B, C, T, H, 1) -> (B, T, H, C)
        (B, C, T, H, W) -> (B, T, H, W, C)

        Padding is detected purely from size-1 trailing dims, so a genuine
        spatial extent of 1 is indistinguishable from padding and is
        stripped as well.
        """
        shape = x.shape
        if shape[3] == 1 and shape[4] == 1:
            x = x[:, :, :, 0, 0].permute(0, 2, 1)
        elif shape[4] == 1:
            x = x[:, :, :, :, 0].permute(0, 2, 3, 1)
        else:
            x = x.permute(0, 2, 3, 4, 1)
        return x

    def forward(self, x):
        """Compute the masked training losses for one batch.

        Args:
            x: Batch dict holding (under the configured ``input_keys``) the
                channel-last feature tensor, whose first two dims are
                (batch, time), and the per-sample valid lengths.

        Returns:
            Dict with scalar tensors ``"total"``, ``"recons"`` and ``"kl"``,
            where ``total = LAMBDA_FEATURE * recons + LAMBDA_KL * kl``.

        Raises:
            ValueError: If fewer lengths than batch entries are supplied.
        """
        inputs = self._extract_inputs(x)
        features = inputs["feature"]
        feature_length = inputs["feature_length"]
        features = (features - self.mean) / self.std

        batch_size, seq_len = features.shape[:2]
        lengths = torch.as_tensor(
            feature_length, device=features.device
        ).reshape(-1)[:batch_size]
        if lengths.numel() < batch_size:
            raise ValueError(
                f"expected {batch_size} feature lengths, got {lengths.numel()}"
            )
        # True for the valid (non-padded) frames of each sample.
        mask = self._length_mask(lengths, seq_len)

        x_in = self.preprocess(features)
        mu, log_var = self.model.encode(
            x_in, scale=[0, 1], return_dist=True
        )
        z = self.model.reparameterize(mu, log_var)
        x_decoder = self.model.decode(z, scale=[0, 1])
        x_out = self.postprocess(x_decoder)

        # The decoder may return a different number of frames than it was
        # given; compare only the overlapping prefix.
        if x_out.size(1) != features.size(1):
            min_len = min(x_out.size(1), features.size(1))
            x_out = x_out[:, :min_len]
            features = features[:, :min_len]
            mask = mask[:, :min_len]

        # Append singleton dims so the (B, T) mask broadcasts over the
        # remaining feature dims.
        mask_expanded = mask[(...,) + (None,) * (features.ndim - 2)]
        loss_per_element = self.RECONS_LOSS(x_out, features)
        # clamp(min=1) keeps the loss finite (zero) when no frame is valid.
        # NOTE(review): for 4D/5D inputs this is not divided by the spatial
        # extent (H * W), unlike the KL term below — confirm intended.
        loss_recons = (
            (loss_per_element * mask_expanded * self.recons_weights).sum()
            / mask_expanded.sum().clamp(min=1)
            / self.recons_weights.sum()
        )

        # Valid latent frames per sample: ceil(length / downsample_factor).
        latent_lengths = torch.div(
            lengths + self.downsample_factor - 1,
            self.downsample_factor,
            rounding_mode="floor",
        )
        mask_downsampled = self._length_mask(latent_lengths, mu.size(2))
        # (B, T') -> (B, 1, T', 1, 1) to broadcast against mu / log_var.
        mask_latent = (
            mask_downsampled.unsqueeze(1).unsqueeze(-1).unsqueeze(-1)
        )

        # Element-wise KL(N(mu, exp(log_var)) || N(0, 1)).
        kl_per_element = -0.5 * (1 + log_var - mu.pow(2) - log_var.exp())
        kl_masked = kl_per_element * mask_latent

        # Average over valid latent frames times (channels * H * W).
        num_latent_elements = mu.size(1) * mu.size(3) * mu.size(4)
        kl_denominator = (
            torch.sum(mask_downsampled) * num_latent_elements
        ).clamp(min=1)
        kl_loss = torch.sum(kl_masked) / kl_denominator

        total_loss = (
            self.LAMBDA_FEATURE * loss_recons
            + self.LAMBDA_KL * kl_loss
        )

        return {
            "total": total_loss,
            "recons": loss_recons,
            "kl": kl_loss,
        }

    def encode(self, x):
        """Normalize ``x`` and return the wrapped model's encoding, channel-last.

        The result of ``self.model.encode`` is presumably the posterior mean
        — confirm against ``WanVAE_``.
        """
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)
        mu = self.model.encode(x_in, scale=[0, 1])
        mu = self.postprocess(mu)
        return mu

    def decode(self, mu):
        """Decode channel-last latents ``mu`` and undo the input normalization."""
        mu_in = self.preprocess(mu)
        x_decoder = self.model.decode(mu_in, scale=[0, 1])
        x_out = self.postprocess(x_decoder)
        x_out = x_out * self.std + self.mean
        return x_out

    @torch.no_grad()
    def stream_encode(self, x, first_chunk=True):
        """Gradient-free :meth:`encode` using the model's ``stream_encode``.

        ``first_chunk`` is forwarded unchanged; presumably it marks the start
        of a new stream — confirm against ``WanVAE_``.
        """
        x = (x - self.mean) / self.std
        x_in = self.preprocess(x)
        mu = self.model.stream_encode(x_in, first_chunk=first_chunk, scale=[0, 1])
        mu = self.postprocess(mu)
        return mu

    @torch.no_grad()
    def stream_decode(self, mu, first_chunk=True):
        """Gradient-free :meth:`decode` using the model's ``stream_decode``.

        ``first_chunk`` is forwarded unchanged; presumably it marks the start
        of a new stream — confirm against ``WanVAE_``.
        """
        mu_in = self.preprocess(mu)
        x_decoder = self.model.stream_decode(
            mu_in, first_chunk=first_chunk, scale=[0, 1]
        )
        x_out = self.postprocess(x_decoder)
        x_out = x_out * self.std + self.mean
        return x_out

    def clear_cache(self):
        """Delegate to the wrapped model's ``clear_cache``."""
        self.model.clear_cache()

    def generate(self, x):
        """Reconstruct a batch through encode -> decode and trim the padding.

        Args:
            x: Batch dict with the feature tensor and per-sample lengths
                (looked up through ``input_keys``).

        Returns:
            ``{"generated": [tensor, ...]}`` with one reconstruction per
            sample. A sample of length ``L > 0`` keeps its first
            ``(L - 1) // downsample_factor * downsample_factor + 1`` frames;
            a sample with ``L <= 0`` yields an empty tensor.
        """
        inputs = self._extract_inputs(x)
        features = inputs["feature"]
        feature_length = inputs["feature_length"]
        y_hat = self.decode(self.encode(features))

        factor = self.downsample_factor
        y_hat_out = []
        for i in range(y_hat.shape[0]):
            length = int(feature_length[i])
            if length <= 0:
                # Without this guard the formula goes negative and the slice
                # would keep most of the padded sequence.
                valid_len = 0
            else:
                valid_len = (length - 1) // factor * factor + 1
            y_hat_out.append(y_hat[i, :valid_len])

        return {"generated": y_hat_out}
|
|