Spaces:
Running on Zero
Running on Zero
| from typing import Tuple | |
| import torch | |
| from model.base import BaseModel | |
| from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper | |
class MMDiffusion(BaseModel):
    """
    Diffusion (flow-matching) training module for a source-conditioned causal Wan generator.

    The generator's ``patch_embedding`` is widened from 16 to 32 input channels so that an
    extra source-video latent can be supplied (passed to the generator as ``y`` in
    ``generator_loss``). The extra channels are zero-initialized; see ``_expand_input_layer``.
    """

    def __init__(self, args, device):
        """
        Initialize the diffusion loss module.

        Args:
            args: training config. Read here: ``num_frame_per_block``, ``independent_first_frame``,
                ``gradient_checkpointing``, ``num_train_timestep``, ``guidance_scale``,
                ``timestep_shift``, ``teacher_forcing``, ``noise_augmentation_max_timestep``
                (``BaseModel`` / ``_initialize_models`` may read more).
            device: device forwarded to ``BaseModel``.
        """
        # NOTE(review): BaseModel.__init__ presumably calls self._initialize_models(args, device),
        # since self.generator is used immediately below and is created nowhere else here.
        super().__init__(args, device)

        # Step 1: configure the causal generator.
        self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
        if self.num_frame_per_block > 1:
            self.generator.model.num_frame_per_block = self.num_frame_per_block

        self.independent_first_frame = getattr(args, "independent_first_frame", False)
        if self.independent_first_frame:
            self.generator.model.independent_first_frame = True

        if args.gradient_checkpointing:
            self.generator.enable_gradient_checkpointing()

        # Step 2: Initialize all hyperparameters
        self.num_train_timestep = args.num_train_timestep
        self.min_step = int(0.02 * self.num_train_timestep)
        self.max_step = int(0.98 * self.num_train_timestep)
        self.guidance_scale = args.guidance_scale
        self.timestep_shift = getattr(args, "timestep_shift", 1.0)

        # Teacher forcing and clean-context noise augmentation are kept as attributes for config
        # compatibility, but generator_loss never feeds clean context latents to the generator,
        # so they have no effect. Say so instead of silently ignoring them.
        self.teacher_forcing = getattr(args, "teacher_forcing", False)
        self.noise_augmentation_max_timestep = getattr(args, "noise_augmentation_max_timestep", 0)
        if self.teacher_forcing or self.noise_augmentation_max_timestep > 0:
            print(
                "[MMDiffusion] Warning: teacher_forcing / noise_augmentation_max_timestep are set, "
                "but MMDiffusion.generator_loss does not use them; they have no effect."
            )

    def _initialize_models(self, args, device):
        """
        Build the trainable causal generator, the frozen text encoder and VAE, and the scheduler.

        Args:
            args: config; ``args.model_kwargs`` (optional) is forwarded to ``WanDiffusionWrapper``.
                It must not contain ``is_causal``: that keyword is always passed explicitly,
                and a duplicate raises ``TypeError``.
            device: device for the scheduler timesteps and the widened input layer.
        """
        # The generator is always built causal; log the value that is actually used.
        is_causal = True
        self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=is_causal)
        self.generator.model.requires_grad_(True)
        print("Initialized WanDiffusionWrapper for MMDiffusion is_causal =", is_causal)

        # Widen the input projection so the source latent can be fed as extra channels.
        self._expand_input_layer(args, device)

        self.text_encoder = WanTextEncoder()
        self.text_encoder.requires_grad_(False)

        self.vae = WanVAEWrapper()
        self.vae.requires_grad_(False)

        self.scheduler = self.generator.get_scheduler()
        self.scheduler.timesteps = self.scheduler.timesteps.to(device)

    def _expand_input_layer(self, args, device):
        """
        Expand the Wan model's ``patch_embedding`` (an ``nn.Conv3d``) to 32 input channels.

        Weight initialization:
          * The first ``in_channels`` input channels copy the pretrained weights, so the model
            still knows how to process the noisy video.
          * The remaining channels (the source-video input) are zero-initialized. Early in
            training the source input therefore does not perturb the prediction, and the model
            can smoothly start learning the editing behavior.
          * The bias (if any) is copied unchanged.

        A layer that already has 32 input channels is left untouched. This avoids expanding
        twice, for example when resuming training.

        Args:
            args: unused; kept for interface compatibility.
            device: device the new layer is placed on.

        Raises:
            ValueError: if the existing layer has more than 32 input channels.
        """
        target_in_channels = 32
        old_proj = self.generator.model.patch_embedding
        old_in_channels = old_proj.in_channels

        # Already expanded (e.g. a resumed checkpoint): nothing to do.
        if old_in_channels == target_in_channels:
            return
        if old_in_channels > target_in_channels:
            raise ValueError(
                f"patch_embedding already has {old_in_channels} input channels; "
                f"cannot expand to {target_in_channels}."
            )

        print(f"Expanding patch_embedding input channels: {old_proj.in_channels} -> 32")

        # Mirror every geometric setting of the layer being replaced; only in_channels changes.
        new_proj = torch.nn.Conv3d(
            in_channels=target_in_channels,
            out_channels=old_proj.out_channels,
            kernel_size=old_proj.kernel_size,
            stride=old_proj.stride,
            padding=old_proj.padding,
            dilation=old_proj.dilation,
            padding_mode=old_proj.padding_mode,
            bias=old_proj.bias is not None,
        ).to(device=device, dtype=old_proj.weight.dtype)

        with torch.no_grad():
            new_proj.weight.zero_()
            # Copy exactly as many channels as the old layer has.
            new_proj.weight[:, :old_in_channels].copy_(old_proj.weight)
            if old_proj.bias is not None:
                new_proj.bias.copy_(old_proj.bias)

        # Replace the original layer.
        self.generator.model.patch_embedding = new_proj

    def generator_loss(
        self,
        conditional_dict: dict,
        target_latent: torch.Tensor,
    ) -> Tuple[torch.Tensor, dict]:
        """
        Compute the timestep-weighted flow-matching loss for the source-conditioned generator.

        Per-frame timesteps are sampled and ``target_latent`` is noised with the scheduler. The
        generator is then run on the noisy latents, receiving ``conditional_dict["source_latent"]``
        as ``y``. Its flow prediction is regressed onto the scheduler's training target with
        a per-frame MSE weighted by ``scheduler.training_weight``.

        Input:
            - conditional_dict: conditioning forwarded to the generator (e.g. text embeddings).
              It must also contain ``"source_latent"``, which is passed to the generator as
              ``y``. The original call site annotates it as [B, T, C, H, W] (TODO confirm).
            - target_latent: clean target latents. Rank 5, with batch and frame as the first
              two dims (documented elsewhere as [B, F, C, H, W]).
        Output:
            - loss: a scalar tensor.
            - log_dict: ``{"x0": target_latent, "x0_pred": predicted clean latent}``, both detached.

        Note: clean-context noise augmentation (``noise_augmentation_max_timestep``) is not
        applied, because the generator call below takes no clean context latents.
        """
        noise = torch.randn_like(target_latent)
        batch_size, num_frame = target_latent.shape[:2]
        source_latent = conditional_dict["source_latent"]

        # Sample per-(batch, frame) timestep indices and noise the clean target accordingly.
        index = self._get_timestep(
            0,
            self.scheduler.num_train_timesteps,
            batch_size,
            num_frame,
            self.num_frame_per_block,
            uniform_timestep=False,
        )
        timestep = self.scheduler.timesteps[index].to(dtype=self.dtype, device=self.device)
        noisy_latents = self.scheduler.add_noise(
            target_latent.flatten(0, 1),
            noise.flatten(0, 1),
            timestep.flatten(0, 1),
        ).unflatten(0, (batch_size, num_frame))
        training_target = self.scheduler.training_target(target_latent, noise, timestep)

        flow_pred, x0_pred = self.generator(
            noisy_image_or_video=noisy_latents,
            conditional_dict=conditional_dict,
            timestep=timestep,
            y=source_latent,  # [B,T,C,H,W]
        )

        # Per-(batch, frame) MSE: average over the last three dims, then weight by timestep.
        loss = torch.nn.functional.mse_loss(
            flow_pred.float(), training_target.float(), reduction="none"
        ).mean(dim=(2, 3, 4))
        loss = loss * self.scheduler.training_weight(timestep).unflatten(0, (batch_size, num_frame))
        loss = loss.mean()

        log_dict = {
            "x0": target_latent.detach(),
            "x0_pred": x0_pred.detach(),
        }
        return loss, log_dict