# Source: LiveEdit / model/mm_diffusion.py
# Uploaded by multimodalart (HF Staff) via "Upload folder using huggingface_hub"
# Commit 70d51ed (verified); file size 7.43 kB
from typing import Tuple
import torch
from model.base import BaseModel
from utils.wan_wrapper import WanDiffusionWrapper, WanTextEncoder, WanVAEWrapper
class MMDiffusion(BaseModel):
    """Causal Wan diffusion model trained to denoise a target video given a source video.

    The generator's ``patch_embedding`` is widened (see ``_expand_input_layer``)
    so that extra input channels are available for the source-video latent,
    which ``generator_loss`` passes to the generator as ``y``. How the wrapper
    combines ``y`` with the noisy latent is not visible here — presumably a
    channel-wise concatenation; verify against ``WanDiffusionWrapper``.
    """

    def __init__(self, args, device):
        """Initialize the model and read diffusion-training hyperparameters from ``args``.

        Args:
            args: Config object. Required attributes: ``gradient_checkpointing``,
                ``num_train_timestep``, ``guidance_scale``. Optional attributes
                (defaults in parentheses): ``num_frame_per_block`` (1),
                ``independent_first_frame`` (False), ``timestep_shift`` (1.0),
                ``teacher_forcing`` (False), ``noise_augmentation_max_timestep`` (0),
                ``model_kwargs`` ({}).
            device: Device for the expanded input layer and the scheduler timesteps.
        """
        # NOTE(review): self.generator is used right after this call and is only
        # created in _initialize_models, so BaseModel.__init__ presumably calls it.
        super().__init__(args, device)
        self.num_frame_per_block = getattr(args, "num_frame_per_block", 1)
        if self.num_frame_per_block > 1:
            self.generator.model.num_frame_per_block = self.num_frame_per_block
        self.independent_first_frame = getattr(args, "independent_first_frame", False)
        if self.independent_first_frame:
            self.generator.model.independent_first_frame = True
        if args.gradient_checkpointing:
            self.generator.enable_gradient_checkpointing()

        # Diffusion-training hyperparameters.
        self.num_train_timestep = args.num_train_timestep
        self.min_step = int(0.02 * self.num_train_timestep)
        self.max_step = int(0.98 * self.num_train_timestep)
        self.guidance_scale = args.guidance_scale
        self.timestep_shift = getattr(args, "timestep_shift", 1.0)
        self.teacher_forcing = getattr(args, "teacher_forcing", False)
        # Upper timestep bound for noising clean context latents under teacher forcing.
        # NOTE: neither this nor teacher_forcing is consumed by generator_loss in this
        # class; both are kept as attributes for any code that reads them elsewhere.
        self.noise_augmentation_max_timestep = getattr(args, "noise_augmentation_max_timestep", 0)

    def _initialize_models(self, args, device):
        """Build the trainable causal generator, the frozen text encoder/VAE, and the scheduler.

        Args:
            args: Config object; ``model_kwargs`` (optional) is forwarded to
                ``WanDiffusionWrapper``.
            device: Device for the expanded input layer and scheduler timesteps.
        """
        # The generator is always built causal. Log the value actually passed;
        # the previous message printed args.is_causal, which is never used here.
        is_causal = True
        self.generator = WanDiffusionWrapper(**getattr(args, "model_kwargs", {}), is_causal=is_causal)
        self.generator.model.requires_grad_(True)
        print("Initialized WanDiffusionWrapper for MMDiffusion is_causal =", is_causal)
        self._expand_input_layer(args, device)

        self.text_encoder = WanTextEncoder()
        self.text_encoder.requires_grad_(False)
        self.vae = WanVAEWrapper()
        self.vae.requires_grad_(False)

        self.scheduler = self.generator.get_scheduler()
        self.scheduler.timesteps = self.scheduler.timesteps.to(device)

    def _expand_input_layer(self, args, device, target_in_channels: int = 32):
        """Widen the generator's ``patch_embedding`` Conv3d to ``target_in_channels`` inputs.

        Weight initialization:
          * the first ``in_channels`` input channels get a copy of the pretrained
            weights, so the model still knows how to handle the noisy video;
          * the added channels (intended for the source-video latent) are zeroed;
          * the bias is copied unchanged.
        Assuming the noisy latent occupies the leading channels, the zeroed
        weights mean the source input initially contributes nothing. The
        embedding therefore starts out identical to the pretrained one, and
        editing behavior is learned smoothly from there.

        The layer is left untouched if it already has ``target_in_channels``
        inputs, e.g. when resuming training from an already-expanded checkpoint.

        Args:
            args: Unused; kept for interface compatibility.
            device: Device on which to create the new layer.
            target_in_channels: Input channel count of the new layer (default 32).

        Raises:
            ValueError: If the existing layer already has more than
                ``target_in_channels`` input channels.
        """
        old_proj = self.generator.model.patch_embedding
        old_in = old_proj.in_channels
        if old_in == target_in_channels:
            return
        if old_in > target_in_channels:
            raise ValueError(
                f"patch_embedding already has {old_in} input channels; "
                f"cannot expand to {target_in_channels}"
            )

        print(f"Expanding patch_embedding input channels: {old_in} -> {target_in_channels}")
        new_proj = torch.nn.Conv3d(
            in_channels=target_in_channels,
            out_channels=old_proj.out_channels,
            kernel_size=old_proj.kernel_size,
            stride=old_proj.stride,
            padding=old_proj.padding,
            dilation=old_proj.dilation,
            padding_mode=old_proj.padding_mode,
            bias=(old_proj.bias is not None),
        ).to(device=device, dtype=old_proj.weight.dtype)

        with torch.no_grad():
            new_proj.weight.zero_()
            # Copy exactly as many channels as the old layer has, rather than a
            # hard-coded 16, so other base widths do not hit a shape mismatch.
            new_proj.weight[:, :old_in].copy_(old_proj.weight)
            if old_proj.bias is not None:
                new_proj.bias.copy_(old_proj.bias)

        self.generator.model.patch_embedding = new_proj

    def generator_loss(
        self,
        conditional_dict: dict,
        target_latent: torch.Tensor,
    ) -> Tuple[torch.Tensor, dict]:
        """Compute the source-conditioned diffusion training loss for the generator.

        Steps:
          1. Sample a timestep for each frame via ``_get_timestep``, with
             ``num_frame_per_block`` forwarded.
          2. Noise ``target_latent`` to that timestep.
          3. Have the generator predict the flow, conditioned on
             ``conditional_dict`` and the source-video latent.
          4. Compute the MSE against ``scheduler.training_target``, average it
             over axes 2-4 (C, H, W), weight it per frame by
             ``scheduler.training_weight(timestep)``, and average the result.

        Args:
            conditional_dict: Conditioning forwarded to the generator (e.g. text
                embeddings). Must contain ``"source_latent"``, the source-video
                latent, which is also passed to the generator as ``y``.
            target_latent: Clean latent of the target video, ``[B, F, C, H, W]``.
                It must be rank-5 for the ``mean(dim=(2, 3, 4))`` reduction.

        Returns:
            ``(loss, log_dict)``: ``loss`` is a scalar tensor; ``log_dict`` holds
            detached ``"x0"`` (the target latent) and ``"x0_pred"`` (the
            generator's clean-latent prediction).
        """
        noise = torch.randn_like(target_latent)
        batch_size, num_frame = target_latent.shape[:2]
        source_latent = conditional_dict["source_latent"]

        # Sample a timestep index per frame and look up the actual timesteps.
        index = self._get_timestep(
            0,
            self.scheduler.num_train_timesteps,
            batch_size,
            num_frame,
            self.num_frame_per_block,
            uniform_timestep=False,
        )
        timestep = self.scheduler.timesteps[index].to(dtype=self.dtype, device=self.device)

        # add_noise works on a flat batch: fold frames into the batch axis and back.
        noisy_latents = self.scheduler.add_noise(
            target_latent.flatten(0, 1),
            noise.flatten(0, 1),
            timestep.flatten(0, 1),
        ).unflatten(0, (batch_size, num_frame))
        training_target = self.scheduler.training_target(target_latent, noise, timestep)

        # NOTE: teacher forcing / clean-context noise augmentation is not applied:
        # this generator call takes no clean-context input, so an augmented
        # context computed here would be discarded.
        flow_pred, x0_pred = self.generator(
            noisy_image_or_video=noisy_latents,
            conditional_dict=conditional_dict,
            timestep=timestep,
            y=source_latent,  # [B, T, C, H, W]
        )

        # Per-frame MSE (mean over C, H, W), weighted by the timestep's training weight.
        per_frame_loss = torch.nn.functional.mse_loss(
            flow_pred.float(), training_target.float(), reduction="none"
        ).mean(dim=(2, 3, 4))
        weights = self.scheduler.training_weight(timestep).unflatten(0, (batch_size, num_frame))
        loss = (per_frame_loss * weights).mean()

        log_dict = {
            "x0": target_latent.detach(),
            "x0_pred": x0_pred.detach(),
        }
        return loss, log_dict