| |
| import gc |
| import logging |
| import math |
| import os |
| import random |
| import sys |
| import types |
| from contextlib import contextmanager |
| from copy import deepcopy |
| from functools import partial |
|
|
| import numpy as np |
| import torch |
| import torch.cuda.amp as amp |
| import torch.distributed as dist |
| import torchvision.transforms.functional as TF |
| from decord import VideoReader |
| from PIL import Image |
| from safetensors import safe_open |
| from torchvision import transforms |
| from tqdm import tqdm |
|
|
| from .distributed.fsdp import shard_model |
| from .distributed.sequence_parallel import sp_attn_forward, sp_dit_forward |
| from .distributed.util import get_world_size |
| from .modules.s2v.audio_encoder import AudioEncoder |
| from .modules.s2v.model_s2v import WanModel_S2V, sp_attn_forward_s2v |
| from .modules.t5 import T5EncoderModel |
| from .modules.vae2_1 import Wan2_1_VAE |
| from .utils.fm_solvers import ( |
| FlowDPMSolverMultistepScheduler, |
| get_sampling_sigmas, |
| retrieve_timesteps, |
| ) |
| from .utils.fm_solvers_unipc import FlowUniPCMultistepScheduler |
|
|
|
|
def load_safetensors(path):
    """Read every tensor stored in a safetensors file.

    Args:
        path: Location of the ``.safetensors`` file.

    Returns:
        dict: Maps each key found in the file to the tensor returned by
        ``get_tensor`` (opened with ``framework="pt"``, ``device="cpu"``).
    """
    with safe_open(path, framework="pt", device="cpu") as handle:
        return {name: handle.get_tensor(name) for name in handle.keys()}
|
|
|
|
class WanS2V:
    """Speech-to-video (S2V) generation pipeline.

    Bundles a T5 text encoder, a VAE, an audio encoder and the S2V noise
    prediction model, and exposes `generate` to sample a video from a
    reference image, an audio file and a text prompt.
    """

    def __init__(
        self,
        config,
        checkpoint_dir,
        device_id=0,
        rank=0,
        t5_fsdp=False,
        dit_fsdp=False,
        use_sp=False,
        t5_cpu=False,
        init_on_cpu=True,
        convert_model_dtype=False,
    ):
        r"""
        Initializes the speech-to-video generation model components.

        Args:
            config (EasyDict):
                Object containing model parameters initialized from config.py
            checkpoint_dir (`str`):
                Path to directory containing model checkpoints
            device_id (`int`,  *optional*, defaults to 0):
                Id of target GPU device
            rank (`int`,  *optional*, defaults to 0):
                Process rank. `generate` returns the video only when rank is 0.
            t5_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for T5 model
            dit_fsdp (`bool`, *optional*, defaults to False):
                Enable FSDP sharding for DiT model
            use_sp (`bool`, *optional*, defaults to False):
                Enable distribution strategy of sequence parallel.
            t5_cpu (`bool`, *optional*, defaults to False):
                Whether to place T5 model on CPU. Only works without t5_fsdp.
            init_on_cpu (`bool`, *optional*, defaults to True):
                Enable initializing Transformer Model on CPU. Only works without FSDP or USP.
            convert_model_dtype (`bool`, *optional*, defaults to False):
                Convert DiT model parameters dtype to 'config.param_dtype'.
                Only works without FSDP.
        """
        self.device = torch.device(f"cuda:{device_id}")
        self.config = config
        self.rank = rank
        self.t5_cpu = t5_cpu
        self.init_on_cpu = init_on_cpu

        self.num_train_timesteps = config.num_train_timesteps
        self.param_dtype = config.param_dtype

        # Any distributed strategy overrides the CPU-initialisation request.
        if t5_fsdp or dit_fsdp or use_sp:
            self.init_on_cpu = False

        shard_fn = partial(shard_model, device_id=device_id)
        self.text_encoder = T5EncoderModel(
            text_len=config.text_len,
            dtype=config.t5_dtype,
            device=torch.device('cpu'),
            checkpoint_path=os.path.join(checkpoint_dir, config.t5_checkpoint),
            tokenizer_path=os.path.join(checkpoint_dir, config.t5_tokenizer),
            shard_fn=shard_fn if t5_fsdp else None,
        )

        self.vae = Wan2_1_VAE(
            vae_pth=os.path.join(checkpoint_dir, config.vae_checkpoint),
            device=self.device)

        logging.info(f"Creating WanModel from {checkpoint_dir}")
        if not dit_fsdp:
            self.noise_model = WanModel_S2V.from_pretrained(
                checkpoint_dir,
                torch_dtype=self.param_dtype,
                device_map=self.device)
        else:
            self.noise_model = WanModel_S2V.from_pretrained(
                checkpoint_dir, torch_dtype=self.param_dtype)

        self.noise_model = self._configure_model(
            model=self.noise_model,
            use_sp=use_sp,
            dit_fsdp=dit_fsdp,
            shard_fn=shard_fn,
            convert_model_dtype=convert_model_dtype)

        self.audio_encoder = AudioEncoder(
            model_id=os.path.join(checkpoint_dir,
                                  "wav2vec2-large-xlsr-53-english"))

        self.sp_size = get_world_size() if use_sp else 1

        self.sample_neg_prompt = config.sample_neg_prompt
        self.motion_frames = config.transformer.motion_frames
        self.drop_first_motion = config.drop_first_motion
        self.fps = config.sample_fps
        self.audio_sample_m = 0

    def _configure_model(self, model, use_sp, dit_fsdp, shard_fn,
                         convert_model_dtype):
        """
        Configures a model object. This includes setting evaluation modes,
        applying distributed parallel strategy, and handling device placement.

        Args:
            model (torch.nn.Module):
                The model instance to configure.
            use_sp (`bool`):
                Enable distribution strategy of sequence parallel.
            dit_fsdp (`bool`):
                Enable FSDP sharding for DiT model.
            shard_fn (callable):
                The function to apply FSDP sharding.
            convert_model_dtype (`bool`):
                Convert DiT model parameters dtype to 'config.param_dtype'.
                Only works without FSDP.

        Returns:
            torch.nn.Module:
                The configured model.
        """
        model.eval().requires_grad_(False)

        if use_sp:
            # Swap every block's self-attention forward for the
            # sequence-parallel implementation, bound to that module.
            for block in model.blocks:
                block.self_attn.forward = types.MethodType(
                    sp_attn_forward_s2v, block.self_attn)
            model.use_context_parallel = True

        if dist.is_initialized():
            dist.barrier()

        if dit_fsdp:
            model = shard_fn(model)
        else:
            if convert_model_dtype:
                model.to(self.param_dtype)
            if not self.init_on_cpu:
                model.to(self.device)

        return model

    def get_size_less_than_area(self,
                                height,
                                width,
                                target_area=1024 * 704,
                                divisor=64):
        """
        Pick an output size whose sides are multiples of `divisor` and whose
        area does not exceed `target_area`.

        The input size is scaled by up to 100 candidate factors, from the
        largest allowed scale downwards. Each candidate is rounded up to the
        next multiple of `divisor`, and the first one whose padded area fits
        is returned. If none fits, an aspect-ratio based size rounded down to
        `divisor` is used instead.

        Args:
            height (`int`): Source height in pixels.
            width (`int`): Source width in pixels.
            target_area (`int`, *optional*, defaults to 1024*704):
                Maximum allowed `height * width` of the result.
            divisor (`int`, *optional*, defaults to 64):
                Both returned sides are multiples of this value.

        Returns:
            tuple[int, int]: `(height, width)` of the chosen size.
        """
        max_upper_area = target_area
        if height * width <= target_area:
            # Already small enough: start from the original size and only
            # shrink if the padding pushes the area over the limit.
            min_scale = 0.1
            max_scale = 1.0
        else:
            # Worst case each side gains (divisor - 1) pixels of padding:
            # (height * s + d) * (width * s + d) = max_upper_area
            d = divisor - 1
            b = d * (height + width)
            a = height * width
            c = d**2 - max_upper_area

            # NOTE(review): the textbook discriminant is b**2 - 4*a*c. The
            # `2 * a * c` form gives a smaller (looser) lower bound and is
            # kept as-is so that produced resolutions do not change.
            min_scale = (-b + math.sqrt(b**2 - 2 * a * c)) / (2 * a)
            # Scale at which the unpadded area equals the limit.
            max_scale = math.sqrt(max_upper_area / (height * width))

        for i in range(100):
            scale = max_scale - (max_scale - min_scale) * i / 100
            new_height, new_width = int(height * scale), int(width * scale)

            # Round each side up to the next multiple of `divisor`.
            padded_height = new_height + (divisor -
                                          new_height % divisor) % divisor
            padded_width = new_width + (divisor - new_width % divisor) % divisor

            if padded_height * padded_width <= max_upper_area:
                return padded_height, padded_width

        # Fallback: derive the size from the aspect ratio, rounded down.
        aspect_ratio = width / height
        target_width = int(
            (target_area * aspect_ratio)**0.5 // divisor * divisor)
        target_height = int(
            (target_area / aspect_ratio)**0.5 // divisor * divisor)

        # Never exceed the source size on either side.
        if target_width >= width or target_height >= height:
            target_width = int(width // divisor * divisor)
            target_height = int(height // divisor * divisor)

        return target_height, target_width

    def prepare_default_cond_input(self,
                                   map_shape=(3, 12, 64, 64),
                                   motion_frames=5,
                                   lat_motion_frames=2,
                                   enable_mano=False,
                                   enable_kp=False,
                                   enable_pose=False):
        """
        Build VAE-encoded constant condition maps for the enabled conditions.

        For every enabled flag a constant tensor (1.0 for mano, -1.0 for kp
        and pose) is created, prefixed with `motion_frames` copies of its
        first slice along dim 2, encoded by the VAE, and stripped of the
        first `lat_motion_frames` entries along dim 2.

        NOTE(review): `repeat` is called with five sizes and the concat runs
        along dim 2, which suggests `map_shape` is expected to have five
        entries, while the default has four — TODO confirm with callers.

        Args:
            map_shape (sequence of `int`): Shape of each constant map.
            motion_frames (`int`): Number of leading frames to prepend.
            lat_motion_frames (`int`): Number of leading latent frames dropped.
            enable_mano (`bool`): Include the map filled with 1.0.
            enable_kp (`bool`): Include the first map filled with -1.0.
            enable_pose (`bool`): Include the second map filled with -1.0.

        Returns:
            torch.Tensor or None: Enabled maps concatenated along dim 1, or
            `None` when no flag is enabled.
        """
        default_value = [1.0, -1.0, -1.0]
        cond_enable = [enable_mano, enable_kp, enable_pose]
        cond = []
        for fill, enabled in zip(default_value, cond_enable):
            if not enabled:
                continue
            map_value = torch.ones(
                map_shape, dtype=self.param_dtype, device=self.device) * fill
            cond_lat = torch.cat([
                map_value[:, :, 0:1].repeat(1, 1, motion_frames, 1, 1),
                map_value
            ],
                                 dim=2)
            cond_lat = torch.stack(
                self.vae.encode(cond_lat.to(
                    self.param_dtype)))[:, :, lat_motion_frames:].to(
                        self.param_dtype)
            cond.append(cond_lat)

        if len(cond) >= 1:
            return torch.cat(cond, dim=1)
        return None

    def encode_audio(self, audio_path, infer_frames):
        """
        Extract audio features and bucket them per video frame.

        Args:
            audio_path (`str`): Path of the audio file.
            infer_frames (`int`): Number of frames per generated clip, passed
                to the audio encoder as `batch_frames`.

        Returns:
            tuple: `(audio_embed_bucket, num_repeat)`. The bucket is moved to
            `self.device` / `self.param_dtype`, gets a leading dim of size 1,
            and has its former first dim moved to the last position.
            `num_repeat` is the value returned by the audio encoder.
        """
        z = self.audio_encoder.extract_audio_feat(
            audio_path, return_all_layers=True)
        audio_embed_bucket, num_repeat = self.audio_encoder.get_audio_embed_bucket_fps(
            z, fps=self.fps, batch_frames=infer_frames, m=self.audio_sample_m)
        audio_embed_bucket = audio_embed_bucket.to(self.device,
                                                   self.param_dtype)
        audio_embed_bucket = audio_embed_bucket.unsqueeze(0)
        if audio_embed_bucket.dim() == 3:
            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 1)
        elif audio_embed_bucket.dim() == 4:
            audio_embed_bucket = audio_embed_bucket.permute(0, 2, 3, 1)
        return audio_embed_bucket, num_repeat

    def read_last_n_frames(self,
                           video_path,
                           n_frames,
                           target_fps=16,
                           reverse=False):
        """
        Read up to `n_frames` frames from a video, sampled at roughly
        `target_fps`, taken from the end of the video (default) or from its
        beginning.

        Parameters:
            video_path (str): Path to the video file.
            n_frames (int): Maximum number of frames to read.
            target_fps (int, optional): Target sampling frame rate. Defaults to 16.
            reverse (bool, optional): If True, sampling starts at frame 0
                (the first frames are read) instead of near the end. Frames
                are always returned in forward order.

        Returns:
            np.ndarray: The sampled frames as returned by
            `VideoReader.get_batch(...).asnumpy()`. Fewer than `n_frames`
            frames are returned when the video is too short.
        """
        vr = VideoReader(video_path)
        original_fps = vr.get_avg_fps()
        total_frames = len(vr)

        # Stride (in source frames) that approximates the target frame rate.
        interval = max(1, round(original_fps / target_fps))

        # Number of source frames spanned by the requested samples.
        required_span = (n_frames - 1) * interval

        if reverse:
            start_frame = 0
        else:
            start_frame = max(0, total_frames - required_span - 1)

        # Indices grow monotonically, so filtering equals stopping at the
        # first out-of-range index.
        candidates = (start_frame + i * interval for i in range(n_frames))
        sampled_indices = [idx for idx in candidates if idx < total_frames]

        return vr.get_batch(sampled_indices).asnumpy()

    def load_pose_cond(self, pose_video, num_repeat, infer_frames, size):
        """
        Build the VAE-encoded pose condition for every clip.

        Args:
            pose_video (`str` or None): Path of the pose video. When None a
                single clip filled with -1 is used.
            num_repeat (`int`): Number of clips.
            infer_frames (`int`): Number of frames per clip.
            size (tuple[int, int]): `(height, width)` of the frames.

        Returns:
            list[torch.Tensor]: One CPU latent per clip (a single entry when
            `pose_video` is None), each with its first entry along dim 2
            removed.
        """
        height, width = size
        if pose_video is not None:
            pose_seq = self.read_last_n_frames(
                pose_video,
                n_frames=infer_frames * num_repeat,
                target_fps=self.fps,
                reverse=True)

            resize_opreat = transforms.Resize(min(height, width))
            crop_opreat = transforms.CenterCrop((height, width))

            # Frames come back channel-last with values in 0..255; move the
            # channel dim forward and rescale to [-1, 1].
            cond_tensor = torch.from_numpy(pose_seq)
            cond_tensor = cond_tensor.permute(0, 3, 1, 2) / 255.0 * 2 - 1.0
            cond_tensor = crop_opreat(resize_opreat(cond_tensor)).permute(
                1, 0, 2, 3).unsqueeze(0)

            # Pad with -1 frames when the video is shorter than requested.
            padding_frame_num = num_repeat * infer_frames - cond_tensor.shape[2]
            cond_tensor = torch.cat([
                cond_tensor,
                -torch.ones([1, 3, padding_frame_num, height, width])
            ],
                                    dim=2)

            cond_tensors = torch.chunk(cond_tensor, num_repeat, dim=2)
        else:
            cond_tensors = [-torch.ones([1, 3, infer_frames, height, width])]

        cond_latents = []
        for cond in cond_tensors:
            # Duplicate the first frame before encoding; the matching first
            # latent is dropped again afterwards.
            cond = torch.cat([cond[:, :, 0:1], cond], dim=2)
            encoded = torch.stack(
                self.vae.encode(
                    cond.to(dtype=self.param_dtype, device=self.device)))
            cond_latents.append(encoded[:, :, 1:].cpu())
        return cond_latents

    def get_gen_size(self, size, max_area, ref_image_path, pre_video_path):
        """
        Determine the generation resolution.

        The base size is `size` when given, otherwise the size of the last
        frame of `pre_video_path` (if truthy) or of the reference image. It
        is then fitted with `get_size_less_than_area`.

        Args:
            size (tuple[int, int] or None): Explicit `(height, width)`.
            max_area (`int`): Maximum pixel area of the result.
            ref_image_path (`str`): Reference image path.
            pre_video_path (`str` or None): Optional preceding video path.

        Returns:
            tuple[int, int]: `(height, width)` to generate at.
        """
        if size is not None:
            height, width = size
        else:
            if pre_video_path:
                ref_image = self.read_last_n_frames(
                    pre_video_path, n_frames=1)[0]
            else:
                # Close the file handle as soon as the pixels are copied.
                with Image.open(ref_image_path) as img:
                    ref_image = np.array(img.convert('RGB'))
            height, width = ref_image.shape[:2]
        height, width = self.get_size_less_than_area(
            height, width, target_area=max_area)
        return (height, width)

    def _build_sample_scheduler(self, sample_solver, sampling_steps, shift):
        """Create a fresh flow-matching scheduler and its timesteps.

        Returns:
            tuple: `(scheduler, timesteps)`.

        Raises:
            NotImplementedError: If `sample_solver` is not 'unipc' or 'dpm++'.
        """
        if sample_solver == 'unipc':
            sample_scheduler = FlowUniPCMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps,
                shift=1,
                use_dynamic_shifting=False)
            sample_scheduler.set_timesteps(
                sampling_steps, device=self.device, shift=shift)
            return sample_scheduler, sample_scheduler.timesteps
        if sample_solver == 'dpm++':
            sample_scheduler = FlowDPMSolverMultistepScheduler(
                num_train_timesteps=self.num_train_timesteps,
                shift=1,
                use_dynamic_shifting=False)
            sampling_sigmas = get_sampling_sigmas(sampling_steps, shift)
            timesteps, _ = retrieve_timesteps(
                sample_scheduler, device=self.device, sigmas=sampling_sigmas)
            return sample_scheduler, timesteps
        raise NotImplementedError("Unsupported solver.")

    def generate(
        self,
        input_prompt,
        ref_image_path,
        audio_path,
        num_repeat=1,
        pose_video=None,
        max_area=720 * 1280,
        infer_frames=80,
        shift=5.0,
        sample_solver='unipc',
        sampling_steps=40,
        guide_scale=5.0,
        n_prompt="",
        seed=-1,
        offload_model=True,
        init_first_frame=False,
    ):
        r"""
        Generates video frames from a reference image, an audio file and a text prompt using a diffusion process.

        Args:
            input_prompt (`str`):
                Text prompt for content generation.
            ref_image_path ('str'):
                Input image path
            audio_path ('str'):
                Audio for video driven
            num_repeat ('int'):
                Number of clips to generate; capped at the clip count derived from the audio.
                If None, that derived count is used.
            pose_video ('str'):
                If provided, uses a sequence of poses to drive the generated video
            max_area (`int`, *optional*, defaults to 720*1280):
                Maximum pixel area for latent space calculation. Controls video resolution scaling
            infer_frames (`int`, *optional*, defaults to 80):
                How many frames to generate per clips. The number should be 4n
            shift (`float`, *optional*, defaults to 5.0):
                Noise schedule shift parameter. Affects temporal dynamics
                [NOTE]: If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
            sample_solver (`str`, *optional*, defaults to 'unipc'):
                Solver used to sample the video ('unipc' or 'dpm++').
            sampling_steps (`int`, *optional*, defaults to 40):
                Number of diffusion sampling steps. Higher values improve quality but slow generation
            guide_scale (`float`, *optional*, defaults 5.0):
                Classifier-free guidance scale. Guidance (and the second,
                unconditional forward pass) is only used when this is > 1.
            n_prompt (`str`, *optional*, defaults to ""):
                Negative prompt for content exclusion. If not given, use `config.sample_neg_prompt`
            seed (`int`, *optional*, defaults to -1):
                Random seed for noise generation. If negative, use a random seed.
                Clip `r` is sampled with `seed + r`.
            offload_model (`bool`, *optional*, defaults to True):
                If True, offloads models to CPU during generation to save VRAM
            init_first_frame (`bool`, *optional*, defaults to False):
                Whether to use the reference image as the first frame (i.e., standard image-to-video generation)

        Returns:
            torch.Tensor or None:
                On rank 0, the generated video frames with dimensions (C, N, H, W) where:
                - C: Color channels
                - N: Total number of frames over all generated clips
                - H: Frame height (from max_area)
                - W: Frame width (from max_area)
                On every other rank, None.

        Raises:
            NotImplementedError: If `sample_solver` is not supported.
        """
        size = self.get_gen_size(
            size=None,
            max_area=max_area,
            ref_image_path=ref_image_path,
            pre_video_path=None)
        HEIGHT, WIDTH = size
        channel = 3

        resize_opreat = transforms.Resize(min(HEIGHT, WIDTH))
        crop_opreat = transforms.CenterCrop((HEIGHT, WIDTH))
        tensor_trans = transforms.ToTensor()

        # `convert` returns an independent, fully loaded image, so the file
        # handle can be released right away.
        with Image.open(ref_image_path) as img:
            ref_image = img.convert('RGB')

        # Pixel-space motion history, initially all zeros.
        motion_latents = torch.zeros(
            [1, channel, self.motion_frames, HEIGHT, WIDTH],
            dtype=self.param_dtype,
            device=self.device)

        # Audio features; the audio also bounds the number of clips.
        audio_emb, nr = self.encode_audio(audio_path, infer_frames=infer_frames)
        if num_repeat is None or num_repeat > nr:
            num_repeat = nr

        lat_motion_frames = (self.motion_frames + 3) // 4
        model_pic = crop_opreat(resize_opreat(ref_image))

        # Reference image -> [1, C, 1, H, W] in [-1, 1] -> latent.
        ref_pixel_values = tensor_trans(model_pic)
        ref_pixel_values = ref_pixel_values.unsqueeze(1).unsqueeze(
            0) * 2 - 1.0
        ref_pixel_values = ref_pixel_values.to(
            dtype=self.vae.dtype, device=self.vae.device)
        ref_latents = torch.stack(self.vae.encode(ref_pixel_values))

        # `detach` shares storage, so the in-place write below is also
        # visible through `videos_last_frames`.
        videos_last_frames = motion_latents.detach()
        drop_first_motion = self.drop_first_motion
        if init_first_frame:
            drop_first_motion = False
            motion_latents[:, :, -6:] = ref_pixel_values
        motion_latents = torch.stack(self.vae.encode(motion_latents))

        # Pose condition latents, one entry per clip.
        pose_cond = self.load_pose_cond(
            pose_video=pose_video,
            num_repeat=num_repeat,
            infer_frames=infer_frames,
            size=size)

        seed = seed if seed >= 0 else random.randint(0, sys.maxsize)

        if n_prompt == "":
            n_prompt = self.sample_neg_prompt

        # Text conditioning (positive and negative prompt).
        if not self.t5_cpu:
            self.text_encoder.model.to(self.device)
            context = self.text_encoder([input_prompt], self.device)
            context_null = self.text_encoder([n_prompt], self.device)
            if offload_model:
                self.text_encoder.model.cpu()
        else:
            context = self.text_encoder([input_prompt], torch.device('cpu'))
            context_null = self.text_encoder([n_prompt], torch.device('cpu'))
            context = [t.to(self.device) for t in context]
            context_null = [t.to(self.device) for t in context_null]

        use_guidance = guide_scale > 1

        # These depend only on the call arguments, not on the clip index.
        lat_target_frames = (infer_frames + 3 +
                             self.motion_frames) // 4 - lat_motion_frames
        target_shape = [lat_target_frames, HEIGHT // 8, WIDTH // 8]
        max_seq_len = np.prod(target_shape) // 4

        out = []
        with torch.amp.autocast(
                'cuda', dtype=self.param_dtype), torch.no_grad():
            for r in range(num_repeat):
                seed_g = torch.Generator(device=self.device)
                seed_g.manual_seed(seed + r)

                noise = [
                    torch.randn(
                        16,
                        target_shape[0],
                        target_shape[1],
                        target_shape[2],
                        dtype=self.param_dtype,
                        device=self.device,
                        generator=seed_g)
                ]

                # A new scheduler is built for every clip.
                sample_scheduler, timesteps = self._build_sample_scheduler(
                    sample_solver, sampling_steps, shift)

                latents = deepcopy(noise)

                # Slice of the audio features that drives this clip.
                left_idx = r * infer_frames
                right_idx = r * infer_frames + infer_frames
                cond_latents = pose_cond[r] if pose_video else pose_cond[0] * 0
                cond_latents = cond_latents.to(
                    dtype=self.param_dtype, device=self.device)
                audio_input = audio_emb[..., left_idx:right_idx]
                input_motion_latents = motion_latents.clone()

                arg_c = {
                    'context': context[0:1],
                    'seq_len': max_seq_len,
                    'cond_states': cond_latents,
                    "motion_latents": input_motion_latents,
                    'ref_latents': ref_latents,
                    "audio_input": audio_input,
                    "motion_frames": [self.motion_frames, lat_motion_frames],
                    "drop_motion_frames": drop_first_motion and r == 0,
                }
                if use_guidance:
                    # Unconditional branch: negative prompt and zeroed audio.
                    arg_null = {
                        'context': context_null[0:1],
                        'seq_len': max_seq_len,
                        'cond_states': cond_latents,
                        "motion_latents": input_motion_latents,
                        'ref_latents': ref_latents,
                        "audio_input": 0.0 * audio_input,
                        "motion_frames": [
                            self.motion_frames, lat_motion_frames
                        ],
                        "drop_motion_frames": drop_first_motion and r == 0,
                    }
                if offload_model or self.init_on_cpu:
                    self.noise_model.to(self.device)
                    torch.cuda.empty_cache()

                for t in tqdm(timesteps):
                    latent_model_input = latents[0:1]
                    timestep = torch.stack([t]).to(self.device)

                    noise_pred_cond = self.noise_model(
                        latent_model_input, t=timestep, **arg_c)

                    if use_guidance:
                        noise_pred_uncond = self.noise_model(
                            latent_model_input, t=timestep, **arg_null)
                        noise_pred = [
                            u + guide_scale * (c - u)
                            for c, u in zip(noise_pred_cond, noise_pred_uncond)
                        ]
                    else:
                        noise_pred = noise_pred_cond

                    temp_x0 = sample_scheduler.step(
                        noise_pred[0].unsqueeze(0),
                        t,
                        latents[0].unsqueeze(0),
                        return_dict=False,
                        generator=seed_g)[0]
                    latents[0] = temp_x0.squeeze(0)

                if offload_model:
                    self.noise_model.cpu()
                    torch.cuda.synchronize()
                    torch.cuda.empty_cache()

                latents = torch.stack(latents)
                drop_now = drop_first_motion and r == 0
                # Decode together with a prefix: the reference latent when
                # motion frames are dropped, otherwise the motion latents.
                if not drop_now:
                    decode_latents = torch.cat([motion_latents, latents], dim=2)
                else:
                    decode_latents = torch.cat([ref_latents, latents], dim=2)
                image = torch.stack(self.vae.decode(decode_latents))
                image = image[:, :, -(infer_frames):]
                if drop_now:
                    image = image[:, :, 3:]

                # Slide the pixel-space history forward and re-encode it as
                # the motion latents of the next clip.
                overlap_frames_num = min(self.motion_frames, image.shape[2])
                videos_last_frames = torch.cat([
                    videos_last_frames[:, :, overlap_frames_num:],
                    image[:, :, -overlap_frames_num:]
                ],
                                               dim=2)
                videos_last_frames = videos_last_frames.to(
                    dtype=motion_latents.dtype, device=motion_latents.device)
                motion_latents = torch.stack(
                    self.vae.encode(videos_last_frames))
                out.append(image.cpu())

        videos = torch.cat(out, dim=2)
        del noise, latents
        del sample_scheduler
        if offload_model:
            gc.collect()
            torch.cuda.synchronize()
        if dist.is_initialized():
            dist.barrier()

        return videos[0] if self.rank == 0 else None
|
|