|
|
import torch |
|
|
from diffsynth import ModelManager, WanVideoPipeline |
|
|
from PIL import Image |
|
|
import argparse |
|
|
from transformers import Wav2Vec2Processor, Wav2Vec2Model |
|
|
import librosa |
|
|
import os |
|
|
import subprocess |
|
|
import cv2 |
|
|
from model import FantasyTalkingAudioConditionModel |
|
|
from utils import save_video, get_audio_features, resize_image_by_longest_edge |
|
|
from pathlib import Path |
|
|
from datetime import datetime |
|
|
|
|
|
from huggingface_hub import snapshot_download |
|
|
|
|
|
def parse_args(): |
|
|
parser = argparse.ArgumentParser(description="Simple example of a training script.") |
|
|
|
|
|
parser.add_argument( |
|
|
"--wan_model_dir", |
|
|
type=str, |
|
|
default="./models/Wan2.1-I2V-14B-720P", |
|
|
required=False, |
|
|
help="The dir of the Wan I2V 14B model.", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--fantasytalking_model_path", |
|
|
type=str, |
|
|
default="./models/fantasytalking_model.ckpt", |
|
|
required=False, |
|
|
help="The .ckpt path of fantasytalking model.", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--wav2vec_model_dir", |
|
|
type=str, |
|
|
default="./models/wav2vec2-base-960h", |
|
|
required=False, |
|
|
help="The dir of wav2vec model.", |
|
|
) |
|
|
|
|
|
parser.add_argument( |
|
|
"--image_path", |
|
|
type=str, |
|
|
default="./assets/images/woman.png", |
|
|
required=False, |
|
|
help="The path of the image.", |
|
|
) |
|
|
|
|
|
parser.add_argument( |
|
|
"--audio_path", |
|
|
type=str, |
|
|
default="./assets/audios/woman.wav", |
|
|
required=False, |
|
|
help="The path of the audio.", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--prompt", |
|
|
type=str, |
|
|
default="A woman is talking.", |
|
|
required=False, |
|
|
help="prompt.", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--output_dir", |
|
|
type=str, |
|
|
default="./output", |
|
|
help="Dir to save the model.", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--image_size", |
|
|
type=int, |
|
|
default=512, |
|
|
help="The image will be resized proportionally to this size.", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--audio_scale", |
|
|
type=float, |
|
|
default=1.0, |
|
|
help="Audio condition injection weight", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--prompt_cfg_scale", |
|
|
type=float, |
|
|
default=5.0, |
|
|
required=False, |
|
|
help="Prompt cfg scale", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--audio_cfg_scale", |
|
|
type=float, |
|
|
default=5.0, |
|
|
required=False, |
|
|
help="Audio cfg scale", |
|
|
) |
|
|
parser.add_argument( |
|
|
"--max_num_frames", |
|
|
type=int, |
|
|
default=81, |
|
|
required=False, |
|
|
help="The maximum frames for generating videos, the audio part exceeding max_num_frames/fps will be truncated." |
|
|
) |
|
|
parser.add_argument( |
|
|
"--fps", |
|
|
type=int, |
|
|
default=23, |
|
|
required=False, |
|
|
) |
|
|
parser.add_argument( |
|
|
"--num_persistent_param_in_dit", |
|
|
type=int, |
|
|
default=None, |
|
|
required=False, |
|
|
help="Maximum parameter quantity retained in video memory, small number to reduce VRAM required" |
|
|
) |
|
|
parser.add_argument( |
|
|
"--seed", |
|
|
type=int, |
|
|
default=1111, |
|
|
required=False, |
|
|
) |
|
|
args = parser.parse_args() |
|
|
return args |
|
|
|
|
|
def load_models(args): |
|
|
|
|
|
|
|
|
snapshot_download("Wan-AI/Wan2.1-I2V-14B-720P", local_dir="./models/Wan2.1-I2V-14B-720P") |
|
|
snapshot_download("facebook/wav2vec2-base-960h", local_dir="./models/wav2vec2-base-960h") |
|
|
snapshot_download("acvlab/FantasyTalking", local_dir="./models") |
|
|
|
|
|
|
|
|
model_manager = ModelManager(device="cpu") |
|
|
|
|
|
model_manager.load_models( |
|
|
[ |
|
|
[ |
|
|
f"{args.wan_model_dir}/diffusion_pytorch_model-00001-of-00007.safetensors", |
|
|
f"{args.wan_model_dir}/diffusion_pytorch_model-00002-of-00007.safetensors", |
|
|
f"{args.wan_model_dir}/diffusion_pytorch_model-00003-of-00007.safetensors", |
|
|
f"{args.wan_model_dir}/diffusion_pytorch_model-00004-of-00007.safetensors", |
|
|
f"{args.wan_model_dir}/diffusion_pytorch_model-00005-of-00007.safetensors", |
|
|
f"{args.wan_model_dir}/diffusion_pytorch_model-00006-of-00007.safetensors", |
|
|
f"{args.wan_model_dir}/diffusion_pytorch_model-00007-of-00007.safetensors", |
|
|
], |
|
|
f"{args.wan_model_dir}/models_clip_open-clip-xlm-roberta-large-vit-huge-14.pth", |
|
|
f"{args.wan_model_dir}/models_t5_umt5-xxl-enc-bf16.pth", |
|
|
f"{args.wan_model_dir}/Wan2.1_VAE.pth", |
|
|
], |
|
|
|
|
|
torch_dtype=torch.bfloat16, |
|
|
) |
|
|
pipe = WanVideoPipeline.from_model_manager(model_manager, torch_dtype=torch.bfloat16, device="cuda") |
|
|
|
|
|
|
|
|
fantasytalking = FantasyTalkingAudioConditionModel(pipe.dit, 768, 2048).to("cuda") |
|
|
fantasytalking.load_audio_processor(args.fantasytalking_model_path, pipe.dit) |
|
|
|
|
|
|
|
|
pipe.enable_vram_management(num_persistent_param_in_dit=args.num_persistent_param_in_dit) |
|
|
|
|
|
|
|
|
wav2vec_processor = Wav2Vec2Processor.from_pretrained(args.wav2vec_model_dir) |
|
|
wav2vec = Wav2Vec2Model.from_pretrained(args.wav2vec_model_dir).to("cuda") |
|
|
|
|
|
return pipe,fantasytalking,wav2vec_processor,wav2vec |
|
|
|
|
|
|
|
|
|
|
|
def main(args,pipe,fantasytalking,wav2vec_processor,wav2vec): |
|
|
os.makedirs(args.output_dir,exist_ok=True) |
|
|
|
|
|
duration = librosa.get_duration(filename=args.audio_path) |
|
|
num_frames = min(int(args.fps*duration//4)*4+5,args.max_num_frames) |
|
|
|
|
|
audio_wav2vec_fea = get_audio_features(wav2vec,wav2vec_processor,args.audio_path,args.fps,num_frames) |
|
|
image = resize_image_by_longest_edge(args.image_path,args.image_size) |
|
|
width, height = image.size |
|
|
|
|
|
audio_proj_fea = fantasytalking.get_proj_fea(audio_wav2vec_fea) |
|
|
pos_idx_ranges = fantasytalking.split_audio_sequence(audio_proj_fea.size(1),num_frames=num_frames) |
|
|
audio_proj_split,audio_context_lens = fantasytalking.split_tensor_with_padding(audio_proj_fea,pos_idx_ranges,expand_length=4) |
|
|
|
|
|
|
|
|
video_audio = pipe( |
|
|
prompt=args.prompt, |
|
|
negative_prompt="人物静止不动,静止,色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走", |
|
|
input_image=image, |
|
|
width=width, |
|
|
height=height, |
|
|
num_frames=num_frames, |
|
|
num_inference_steps=30, |
|
|
seed=args.seed, tiled=True, |
|
|
audio_scale=args.audio_scale, |
|
|
cfg_scale = args.prompt_cfg_scale, |
|
|
audio_cfg_scale=args.audio_cfg_scale, |
|
|
audio_proj=audio_proj_split, |
|
|
audio_context_lens=audio_context_lens, |
|
|
latents_num_frames=(num_frames-1)//4+1 |
|
|
) |
|
|
current_time = datetime.now().strftime("%Y%m%d_%H%M%S") |
|
|
save_path_tmp = f"{args.output_dir}/tmp_{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4" |
|
|
save_video(video_audio, save_path_tmp, fps=args.fps, quality=5) |
|
|
|
|
|
save_path = f"{args.output_dir}/{Path(args.image_path).stem}_{Path(args.audio_path).stem}_{current_time}.mp4" |
|
|
final_command = [ |
|
|
"ffmpeg", "-y", |
|
|
"-i", save_path_tmp, |
|
|
"-i", args.audio_path, |
|
|
"-c:v", "libx264", |
|
|
"-c:a", "aac", |
|
|
"-shortest", |
|
|
save_path |
|
|
] |
|
|
subprocess.run(final_command, check=True) |
|
|
os.remove(save_path_tmp) |
|
|
return save_path |
|
|
|
|
|
if __name__ == "__main__": |
|
|
args = parse_args() |
|
|
pipe,fantasytalking,wav2vec_processor,wav2vec = load_models(args) |
|
|
|
|
|
main(args,pipe,fantasytalking,wav2vec_processor,wav2vec) |
|
|
|