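"""FantasyTalking inference script.

Animates a single portrait image with audio-driven talking-head motion:
wav2vec2 extracts features from the driving audio, a FantasyTalking adapter
injects them into the Wan2.1-I2V-14B-720P image-to-video diffusion pipeline,
and ffmpeg muxes the original audio back into the generated clip.
"""
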
import argparse
import os
import subprocess
from datetime import datetime
from pathlib import Path

import cv2
import librosa
import torch
from diffsynth import ModelManager, WanVideoPipeline
from PIL import Image
from transformers import Wav2Vec2Processor, Wav2Vec2Model
# from modelscope import snapshot_download
from huggingface_hub import snapshot_download

from model import FantasyTalkingAudioConditionModel
from utils import save_video, get_audio_features, resize_image_by_longest_edge

def parse_args():
    parser = argparse.ArgumentParser(
        description="Generate a talking-head video from a single image and an audio clip with FantasyTalking."
    )
    parser.add_argument(
        "--wan_model_dir",
        type=str,
        default="./models/Wan2.1-I2V-14B-720P",
        required=False,
        help="Directory of the Wan2.1 I2V 14B model.",
    )
    parser.add_argument(
        "--fantasytalking_model_path",
        type=str,
        default="./models/fantasytalking_model.ckpt",
        required=False,
        help="Path to the FantasyTalking .ckpt checkpoint.",
    )
    parser.add_argument(
        "--wav2vec_model_dir",
        type=str,
        default="./models/wav2vec2-base-960h",
        required=False,
        help="Directory of the wav2vec2 model.",
    )
    parser.add_argument(
        "--image_path",
        type=str,
        default="./assets/images/woman.png",
        required=False,
        help="Path to the input portrait image.",
    )
    parser.add_argument(
        "--audio_path",
        type=str,
        default="./assets/audios/woman.wav",
        required=False,
        help="Path to the driving audio file.",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        default="A woman is talking.",
        required=False,
        help="Text prompt describing the desired video content.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="./output",
        help="Directory to save the generated videos.",
    )
    parser.add_argument(
        "--image_size",
        type=int,
        default=512,
        help="The input image is resized proportionally so its longest edge matches this size.",
    )
    parser.add_argument(
        "--audio_scale",
        type=float,
        default=1.0,
        help="Audio condition injection weight.",
    )
    parser.add_argument(
        "--prompt_cfg_scale",
        type=float,
        default=5.0,
        required=False,
        help="Classifier-free guidance scale for the text prompt.",
    )
    parser.add_argument(
        "--audio_cfg_scale",
        type=float,
        default=5.0,
        required=False,
        help="Classifier-free guidance scale for the audio condition.",
    )
    parser.add_argument(
        "--max_num_frames",
        type=int,
        default=81,
        required=False,
        help="Maximum number of frames to generate; audio beyond max_num_frames/fps seconds is truncated.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=23,
        required=False,
        help="Frames per second of the output video.",
    )
    parser.add_argument(
        "--num_persistent_param_in_dit",
        type=int,
        default=None,
        required=False,
        help="Maximum number of DiT parameters kept resident in VRAM; set a smaller value to reduce VRAM usage.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=1111,
        required=False,
        help="Random seed for generation.",
    )
    args = parser.parse_args()
    return args
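
# Example invocation (the script filename and the --num_persistent_param_in_dit
# value are illustrative, not tuned recommendations):
#
#   python infer.py \
#       --image_path ./assets/images/woman.png \
#       --audio_path ./assets/audios/woman.wav \
#       --prompt "A woman is talking." \
#       --num_persistent_param_in_dit 7000000000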

def load_models(args):
    # Download model weights from the Hugging Face Hub
    snapshot_download("Wan-AI/Wan2.1-I2V-14B-720P", local_dir="./models/Wan2.1-I2V-14B-720P")
    snapshot_download("facebook/wav2vec2-base-960h", local_dir="./models/wav2vec2-base-960h")
    snapshot_download("acvlab/FantasyTalking", local_dir="./models")

    # Load the Wan I2V models (DiT, CLIP image encoder, T5 text encoder, VAE)
    model_manager = ModelManager(device="cpu")
    model_manager.load_models(
        [
            [
                f"{args.wan_model_dir}/diffusion_pytorch_model-00001-of-00007.safetensors",
                f"{args.wan_model_dir}/diffusion_pytorch_model-00002-of-00007.safetensors",
                f"{args.wan_model_dir}/diffusion_pytorch_model-00003-of-00007.safetensors",
                f"{args.wan_model_dir}/diffusion_pytorch_model-00004-of-00007.safetensors",
                f"{args.wan_model_dir}/diffusion_pytorch_model-00005-of-00007.safetensors",
                f"{args.wan_model_dir}/diffusion_pytorch_model-00006-of-00007.safetensors",
                f"{args.wan_model_dir}/diffusion_pytorch_model-00007-of-00007.safetensors",
            ],
            f"{args.wan_model_dir}/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth",
            f"{args.wan_model_dir}/models_t5_umt5-xxl-enc-bf16.pth",
            f"{args.wan_model_dir}/Wan2.1_VAE.pth",
        ],
        # Set `torch_dtype=torch.float8_e4m3fn` to enable FP8 quantization and reduce VRAM usage.
        torch_dtype=torch.bfloat16,
    )
    pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda")

    # Load FantasyTalking weights
    fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda")
    fantasytalking.load_audio_processor(args.fantasytalking_model_path, pipe.dit)

    # You can set `num_persistent_param_in_dit` to a small number to reduce the VRAM required.
    pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit)

    # Load wav2vec models
    wav2vec_processor = Wav2Vec2Processor.from_pretrained(args.wav2vec_model_dir)
    wav2vec = Wav2Vec2Model.from_pretrained(args.wav2vec_model_dir).to("cuda")

    return pipe, fantasytalking, wav2vec_processor, wav2vec

def main(args, pipe, fantasytalking, wav2vec_processor, wav2vec):
    os.makedirs(args.output_dir, exist_ok=True)

    duration = librosa.get_duration(filename=args.audio_path)
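    # Wan generates clips of 4k+1 frames (its video VAE compresses time 4x),
    # so derive a valid 4k+1 frame count from the audio duration and cap it at
    # max_num_frames (the default 81 frames is about 3.5 s at 23 fps).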
    num_frames = min(int(args.fps * duration // 4) * 4 + 5, args.max_num_frames)
    audio_wav2vec_fea = get_audio_features(wav2vec, wav2vec_processor, args.audio_path, args.fps, num_frames)

    image = resize_image_by_longest_edge(args.image_path, args.image_size)
    width, height = image.size

    # Project the wav2vec features into the audio-conditioning space, then
    # split them into per-latent-frame windows padded with 4 extra positions on
    # each side (expand_length=4).  Resulting shape: [b, 21, 9+8, 768].
    audio_proj_fea = fantasytalking.get_proj_fea(audio_wav2vec_fea)
    pos_idx_ranges = fantasytalking.split_audio_sequence(audio_proj_fea.size(1), num_frames=num_frames)
    audio_proj_split, audio_context_lens = fantasytalking.split_tensor_with_padding(audio_proj_fea, pos_idx_ranges, expand_length=4)

    # Image-to-video generation with dual classifier-free guidance: the text
    # prompt (cfg_scale) and the audio condition (audio_cfg_scale) are guided
    # with separate scales.
    video_audio = pipe(
        prompt=args.prompt,
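        # Standard Wan2.1 Chinese negative prompt.  Rough translation: "static,
        # motionless subject, garish colors, overexposed, blurry details,
        # subtitles, style, artwork, painting, overall gray tint, worst
        # quality, low quality, JPEG artifacts, ugly, mutilated, extra fingers,
        # badly drawn hands, badly drawn face, deformed, disfigured, malformed
        # limbs, fused fingers, frozen frame, cluttered background, three legs,
        # many people in background, walking backwards."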
        negative_prompt="人物静止不动,静止,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走",
        input_image=image,
        width=width,
        height=height,
        num_frames=num_frames,
        num_inference_steps=30,
        seed=args.seed,
        tiled=True,
        audio_scale=args.audio_scale,
        cfg_scale=args.prompt_cfg_scale,
        audio_cfg_scale=args.audio_cfg_scale,
        audio_proj=audio_proj_split,
        audio_context_lens=audio_context_lens,
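        # 4x temporal VAE compression: 4k+1 video frames -> k+1 latent frames.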
        latents_num_frames=(num_frames - 1) // 4 + 1,
    )

    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    save_path_tmp = f"{args.output_dir}/tmp_{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4"
    save_video(video_audio, save_path_tmp, fps=args.fps, quality=5)

    save_path = f"{args.output_dir}/{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4"
    final_command = [
        "ffmpeg", "-y",
        "-i", save_path_tmp,
        "-i", args.audio_path,
        "-c:v", "libx264",
        "-c:a", "aac",
        "-shortest",
        save_path,
    ]
    subprocess.run(final_command, check=True)
    os.remove(save_path_tmp)
    return save_path

if __name__ == "__main__":
    args = parse_args()
    pipe, fantasytalking, wav2vec_processor, wav2vec = load_models(args)
    main(args, pipe, fantasytalking, wav2vec_processor, wav2vec)