|
|
import os |
|
|
import sys |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from diffusers import FlowMatchEulerDiscreteScheduler |
|
|
from omegaconf import OmegaConf |
|
|
from PIL import Image |
|
|
|
|
|
# Make project-local packages importable when this script is run directly:
# add this file's directory and its two nearest ancestors to ``sys.path``
# (each inserted at the front), skipping any entry that is already present.
current_file_path = os.path.abspath(__file__)
project_roots = [
    os.path.dirname(current_file_path),
    os.path.dirname(os.path.dirname(current_file_path)),
    os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))),
]
for project_root in project_roots:
    # A plain ``if`` statement rather than a conditional expression that is
    # evaluated only for its side effect.
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
|
|
|
|
|
from videox_fun.dist import set_multi_gpus_devices, shard_model |
|
|
from videox_fun.models import (AutoencoderKLWan, AutoencoderKLWan3_8, |
|
|
AutoTokenizer, CLIPModel, |
|
|
Wan2_2Transformer3DModel_Animate, |
|
|
WanT5EncoderModel) |
|
|
from videox_fun.models.cache_utils import get_teacache_coefficients |
|
|
from videox_fun.pipeline import Wan2_2AnimatePipeline |
|
|
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler |
|
|
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler |
|
|
from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8, |
|
|
convert_weight_dtype_wrapper, |
|
|
replace_parameters_by_name) |
|
|
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora |
|
|
from videox_fun.utils.utils import (filter_kwargs, get_image, |
|
|
get_image_to_video_latent, |
|
|
get_video_to_video_latent, |
|
|
save_videos_grid) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------------------------------------------------------------------
# Runtime / memory
# ---------------------------------------------------------------------------
# Recognised values: "sequential_cpu_offload", "model_cpu_offload_and_qfloat8",
# "model_cpu_offload", "model_full_load_and_qfloat8"; anything else loads the
# whole pipeline onto the device.
GPU_memory_mode = "sequential_cpu_offload"

# Multi-GPU inference is enabled when either degree is greater than 1.
ulysses_degree, ring_degree = 1, 1
# FSDP sharding flags; they only take effect when multi-GPU inference is on.
fsdp_dit = False
fsdp_text_encoder = True

# When True, every transformer block is wrapped with torch.compile.
compile_dit = False

# ---------------------------------------------------------------------------
# Acceleration: TeaCache, CFG skipping, RIFLEx
# ---------------------------------------------------------------------------
enable_teacache = True
teacache_threshold = 0.1
num_skip_start_steps = 5
teacache_offload = False

# Passed to enable_cfg_skip whenever it is not None (0 is still passed).
cfg_skip_ratio = 0

enable_riflex = False
riflex_k = 6

# ---------------------------------------------------------------------------
# Model, sampler and optional weight overrides
# ---------------------------------------------------------------------------
config_path = "config/wan2.2/wan_civitai_animate.yaml"
model_name = "./models/Diffusion_Transformer/Wan2.2-Animate-14B/"

# One of "Flow", "Flow_Unipc", "Flow_DPM++".
sampler_name = "Flow_Unipc"
# Passed straight to the pipeline call.
shift = 5

# Optional checkpoints loaded on top of the pretrained weights (strict=False).
transformer_path = transformer_high_path = vae_path = None

# Optional LoRA files merged before sampling and unmerged afterwards.
lora_path = lora_high_path = None

# ---------------------------------------------------------------------------
# Pre-processed inputs. Background and mask videos are only used when the
# background video exists (replacement mode).
# ---------------------------------------------------------------------------
src_root_path = "asset/wan_animate/replace/process_results/"
src_pose_path, src_face_path, src_ref_path, src_bg_path, src_mask_path = (
    os.path.join(src_root_path, file_name)
    for file_name in ("src_pose.mp4", "src_face.mp4", "src_ref.png", "src_bg.mp4", "src_mask.mp4")
)

# ---------------------------------------------------------------------------
# Sampling
# ---------------------------------------------------------------------------
sample_size = [480, 832]  # [height, width]
video_length = 81         # later snapped to a length the VAE can encode
fps = 16

weight_dtype = torch.bfloat16

# Positive prompt (Chinese): "the person in the video is performing movements".
prompt = "视频中的人在做动作"
# Negative prompt (Chinese): artefacts to avoid, e.g. oversaturation, overexposure,
# static frames, blur, subtitles, low quality, JPEG artefacts, extra/fused fingers,
# malformed hands/faces/limbs, cluttered background, three legs, walking backwards.
negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

guidance_scale = 4.0
seed = 43
num_inference_steps = 20

lora_weight = lora_high_weight = 0.55

save_path = "samples/wan-videos-animate"
|
|
|
|
|
device = set_multi_gpus_devices(ulysses_degree, ring_degree)
config = OmegaConf.load(config_path)
transformer_kwargs = config['transformer_additional_kwargs']
# Boundary value handed to the pipeline call; 0.875 when the config omits it.
boundary = transformer_kwargs.get('boundary', 0.875)


def _read_state_dict(checkpoint_path):
    """Load and return the state dict stored at ``checkpoint_path``.

    Paths ending in ``safetensors`` are read with ``safetensors.torch.load_file``.
    Anything else goes through ``torch.load`` onto the CPU. If the loaded mapping
    has a top-level ``"state_dict"`` key, that inner mapping is returned instead.
    """
    print(f"From checkpoint: {checkpoint_path}")
    if checkpoint_path.endswith("safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(checkpoint_path)
    else:
        # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
    return state_dict["state_dict"] if "state_dict" in state_dict else state_dict


def _load_override_weights(model, checkpoint_path):
    """Non-strictly load ``checkpoint_path`` into ``model`` and print key-mismatch counts."""
    missing, unexpected = model.load_state_dict(_read_state_dict(checkpoint_path), strict=False)
    print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")


transformer = Wan2_2Transformer3DModel_Animate.from_pretrained(
    os.path.join(model_name, transformer_kwargs.get('transformer_low_noise_model_subpath', 'transformer')),
    transformer_additional_kwargs=OmegaConf.to_container(transformer_kwargs),
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
)
if transformer_kwargs.get('transformer_combination_type', 'single') == "moe":
    # The high-noise model is built with the same (imported) Animate class as the
    # low-noise one. The previous code named ``Wan2_2Transformer3DModel``, which
    # this script never imports, so any "moe" config failed with NameError.
    transformer_2 = Wan2_2Transformer3DModel_Animate.from_pretrained(
        os.path.join(model_name, transformer_kwargs.get('transformer_high_noise_model_subpath', 'transformer')),
        transformer_additional_kwargs=OmegaConf.to_container(transformer_kwargs),
        low_cpu_mem_usage=True,
        torch_dtype=weight_dtype,
    )
else:
    transformer_2 = None

# Optional weight overrides applied on top of the pretrained transformers.
if transformer_path is not None:
    _load_override_weights(transformer, transformer_path)
if transformer_2 is not None and transformer_high_path is not None:
    _load_override_weights(transformer_2, transformer_high_path)

# The VAE class is selected by name from the config (default AutoencoderKLWan).
vae_kwargs = config['vae_kwargs']
Chosen_AutoencoderKL = {
    "AutoencoderKLWan": AutoencoderKLWan,
    "AutoencoderKLWan3_8": AutoencoderKLWan3_8,
}[vae_kwargs.get('vae_type', 'AutoencoderKLWan')]
vae = Chosen_AutoencoderKL.from_pretrained(
    os.path.join(model_name, vae_kwargs.get('vae_subpath', 'vae')),
    additional_kwargs=OmegaConf.to_container(vae_kwargs),
).to(weight_dtype)

if vae_path is not None:
    _load_override_weights(vae, vae_path)
|
|
|
|
|
|
|
|
# Text side: tokenizer and T5 text encoder, each from a config-overridable subfolder.
text_encoder_kwargs = config['text_encoder_kwargs']

tokenizer_dir = os.path.join(model_name, text_encoder_kwargs.get('tokenizer_subpath', 'tokenizer'))
tokenizer = AutoTokenizer.from_pretrained(tokenizer_dir)

text_encoder_dir = os.path.join(model_name, text_encoder_kwargs.get('text_encoder_subpath', 'text_encoder'))
text_encoder = WanT5EncoderModel.from_pretrained(
    text_encoder_dir,
    additional_kwargs=OmegaConf.to_container(text_encoder_kwargs),
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
).eval()

# Image side: CLIP image encoder, cast to the run's weight dtype and set to eval mode.
image_encoder_dir = os.path.join(model_name, config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder'))
clip_image_encoder = CLIPModel.from_pretrained(image_encoder_dir).to(weight_dtype).eval()
|
|
|
|
|
|
|
|
# Map sampler names to scheduler classes. ``scheduler_dict`` now holds the mapping
# itself; previously it was chain-assigned the selected class.
scheduler_dict = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
if sampler_name not in scheduler_dict:
    raise ValueError(f"Unknown sampler_name {sampler_name!r}; expected one of {list(scheduler_dict)}")
Chosen_Scheduler = scheduler_dict[sampler_name]

if sampler_name in ("Flow_Unipc", "Flow_DPM++"):
    # These solvers are constructed with shift=1, while the pipeline call in this
    # script receives ``shift`` explicitly. Presumably this avoids applying the
    # shift twice — TODO confirm against the solver implementations.
    config['scheduler_kwargs']['shift'] = 1

scheduler = Chosen_Scheduler(
    **filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
)
|
|
|
|
|
|
|
|
# Assemble the animation pipeline from the loaded components.
pipeline = Wan2_2AnimatePipeline(
    transformer=transformer,
    transformer_2=transformer_2,
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    clip_image_encoder=clip_image_encoder,
    scheduler=scheduler,
)

# Multi-GPU setup (and optional FSDP sharding) only when a parallel degree exceeds 1.
if ulysses_degree > 1 or ring_degree > 1:
    from functools import partial

    for parallel_dit in (transformer, transformer_2):
        if parallel_dit is not None:
            parallel_dit.enable_multi_gpus_inference()

    if fsdp_dit or fsdp_text_encoder:
        shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype)

    if fsdp_dit:
        pipeline.transformer = shard_fn(pipeline.transformer)
        if transformer_2 is not None:
            pipeline.transformer_2 = shard_fn(pipeline.transformer_2)
        print("Add FSDP DIT")

    if fsdp_text_encoder:
        pipeline.text_encoder = shard_fn(pipeline.text_encoder)
        print("Add FSDP TEXT ENCODER")

# Optionally wrap every transformer block with torch.compile.
if compile_dit:
    dits_to_compile = [pipeline.transformer]
    if transformer_2 is not None:
        dits_to_compile.append(pipeline.transformer_2)
    for compiled_dit in dits_to_compile:
        for block_idx, dit_block in enumerate(list(compiled_dit.blocks)):
            compiled_dit.blocks[block_idx] = torch.compile(dit_block)
    print("Add Compile")
|
|
|
|
|
# Place the pipeline according to GPU_memory_mode. Each per-transformer step runs
# on the primary transformer first, then on transformer_2 when it exists.
memory_dits = [dit for dit in (transformer, transformer_2) if dit is not None]

if GPU_memory_mode == "sequential_cpu_offload":
    # "modulation" parameters and rotary freqs are moved to the device up front.
    for dit in memory_dits:
        replace_parameters_by_name(dit, ["modulation",], device=device)
        dit.freqs = dit.freqs.to(device=device)
    pipeline.enable_sequential_cpu_offload(device=device)
elif GPU_memory_mode in ("model_cpu_offload_and_qfloat8", "model_full_load_and_qfloat8"):
    # Store transformer weights as float8 (except "modulation") and wrap them so
    # they are used in weight_dtype.
    for dit in memory_dits:
        convert_model_weight_to_float8(dit, exclude_module_name=["modulation",], device=device)
        convert_weight_dtype_wrapper(dit, weight_dtype)
    if GPU_memory_mode == "model_cpu_offload_and_qfloat8":
        pipeline.enable_model_cpu_offload(device=device)
    else:
        pipeline.to(device=device)
elif GPU_memory_mode == "model_cpu_offload":
    pipeline.enable_model_cpu_offload(device=device)
else:
    pipeline.to(device=device)
|
|
|
|
|
# TeaCache: coefficients are looked up only when enabled. If any are found,
# caching is enabled on the primary transformer, and transformer_2 (if present)
# is pointed at it via share_teacache.
coefficients = None
if enable_teacache:
    coefficients = get_teacache_coefficients(model_name)

if coefficients is not None:
    print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.")
    pipeline.transformer.enable_teacache(
        coefficients,
        num_inference_steps,
        teacache_threshold,
        num_skip_start_steps=num_skip_start_steps,
        offload=teacache_offload,
    )
    if transformer_2 is not None:
        pipeline.transformer_2.share_teacache(transformer=pipeline.transformer)

# CFG skipping is configured for any non-None ratio (a ratio of 0 still counts).
if cfg_skip_ratio is not None:
    print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.")
    pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, num_inference_steps)
    if transformer_2 is not None:
        pipeline.transformer_2.share_cfg_skip(transformer=pipeline.transformer)
|
|
|
|
|
generator = torch.Generator(device=device).manual_seed(seed)

# LoRA handling. The low-noise LoRA (lora_path) is merged into the pipeline's
# default transformer. The high-noise LoRA (lora_high_path) goes into
# transformer_2 and is merged only when that model exists AND a path is set.
# Previously the high-noise merge was nested under ``lora_path is not None``:
# it passed ``None`` as the path when lora_high_path was unset, and silently
# ignored a configured high-noise LoRA when lora_path was unset. The same flags
# gate the unmerge below, so merge and unmerge always stay paired.
merge_low_lora = lora_path is not None
merge_high_lora = transformer_2 is not None and lora_high_path is not None

if merge_low_lora:
    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
if merge_high_lora:
    pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")

with torch.no_grad():
    temporal_ratio = vae.config.temporal_compression_ratio
    # Round the frame count down to ``k * temporal_ratio + 1`` (single frames are
    # kept), then derive the matching number of latent frames.
    video_length = int((video_length - 1) // temporal_ratio * temporal_ratio) + 1 if video_length != 1 else 1
    latent_frames = (video_length - 1) // temporal_ratio + 1

    if enable_riflex:
        pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames)
        if transformer_2 is not None:
            pipeline.transformer_2.enable_riflex(k = riflex_k, L_test = latent_frames)

    # Conditioning inputs: pose video at the output size, face crops at 512x512,
    # plus the reference image.
    pose_video, _, _, _ = get_video_to_video_latent(src_pose_path, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None)
    face_video, _, _, _ = get_video_to_video_latent(src_face_path, video_length=video_length, sample_size=[512, 512], fps=fps, ref_image=None)
    ref_image = get_image(src_ref_path)

    # Replacement mode is enabled when a background video is present; it also
    # needs the matching mask video.
    if os.path.exists(src_bg_path):
        if not os.path.exists(src_mask_path):
            raise FileNotFoundError(
                f"Background video {src_bg_path} exists but mask video {src_mask_path} is missing."
            )
        bg_video, _, _, _ = get_video_to_video_latent(src_bg_path, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None)
        mask_video, _, _, _ = get_video_to_video_latent(src_mask_path, video_length=video_length, sample_size=sample_size, fps=fps, ref_image=None)
        # Keep only the first slice along dim 1 — presumably the channel axis
        # (a single-channel mask); TODO confirm the helper's output layout.
        mask_video = mask_video[:, :1]
        replace_flag = True
    else:
        bg_video = None
        mask_video = None
        replace_flag = False

    sample = pipeline(
        prompt,
        num_frames = video_length,
        negative_prompt = negative_prompt,
        height = sample_size[0],
        width = sample_size[1],
        generator = generator,
        guidance_scale = guidance_scale,
        num_inference_steps = num_inference_steps,
        boundary = boundary,
        pose_video = pose_video,
        face_video = face_video,
        ref_image = ref_image,
        bg_video = bg_video,
        mask_video = mask_video,
        replace_flag = replace_flag,
        shift = shift,
    ).videos

if merge_low_lora:
    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
if merge_high_lora:
    pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")
|
|
|
|
|
def save_results():
    """Write the generated ``sample`` into ``save_path`` under a fresh 8-digit index.

    Reads the module-level ``sample``, ``save_path``, ``video_length`` and ``fps``.
    When ``video_length == 1`` the first frame is saved as a PNG; otherwise the
    result is saved as an MP4 via ``save_videos_grid`` at ``fps``.
    """
    os.makedirs(save_path, exist_ok=True)

    suffix = ".png" if video_length == 1 else ".mp4"

    # Start at "number of entries + 1" as before, but advance past any index
    # whose file already exists. After earlier outputs are deleted, the entry
    # count no longer tracks the highest index, and the old scheme silently
    # overwrote an existing result.
    index = len(os.listdir(save_path)) + 1
    while os.path.exists(os.path.join(save_path, str(index).zfill(8) + suffix)):
        index += 1
    prefix = str(index).zfill(8)
    video_path = os.path.join(save_path, prefix + suffix)

    if video_length == 1:
        # Take batch item 0 and index 0 along dim 2. This assumes ``sample`` is
        # (batch, channels, frames, height, width) on the CPU with values in
        # [0, 1] — TODO confirm against the pipeline output.
        image = sample[0, :, 0]
        image = image.transpose(0, 1).transpose(1, 2)  # presumably (C, H, W) -> (H, W, C)
        image = (image * 255).numpy().astype(np.uint8)
        Image.fromarray(image).save(video_path)
    else:
        save_videos_grid(sample, video_path, fps=fps)
|
|
|
|
|
# With multi-GPU inference every rank produces the same result, so only rank 0
# writes it; a single-process run always saves.
if ulysses_degree * ring_degree > 1:
    import torch.distributed as dist
    is_main_process = dist.get_rank() == 0
else:
    is_main_process = True

if is_main_process:
    save_results()
|
|
|