# yongqiang
# initialize this repo
# ba96580
import os
import sys
import numpy as np
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from omegaconf import OmegaConf
from PIL import Image
# Make the project importable when this script is run directly: the script's
# own directory and its two parent directories are put on sys.path so the
# `videox_fun` imports below can resolve from any of them.
current_file_path = os.path.abspath(__file__)
project_roots = [
    os.path.dirname(current_file_path),
    os.path.dirname(os.path.dirname(current_file_path)),
    os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))),
]
for project_root in project_roots:
    # Prepend (so these directories take precedence over later sys.path
    # entries), but never add a directory that is already present.
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
from videox_fun.dist import set_multi_gpus_devices, shard_model
from videox_fun.models import (AutoencoderKLWan, AutoencoderKLWan3_8,
AutoTokenizer, CLIPModel,
Wan2_2Transformer3DModel_Animate,
WanT5EncoderModel)
from videox_fun.models.cache_utils import get_teacache_coefficients
from videox_fun.pipeline import Wan2_2AnimatePipeline
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8,
convert_weight_dtype_wrapper,
replace_parameters_by_name)
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from videox_fun.utils.utils import (filter_kwargs, get_image,
get_image_to_video_latent,
get_video_to_video_latent,
save_videos_grid)
# GPU memory mode -- one of:
#   "model_full_load"               whole pipeline moved to the GPU.
#   "model_full_load_and_qfloat8"   as above, with the transformer weights
#                                   quantized to float8 to save more GPU memory.
#   "model_cpu_offload"             each model is moved back to the CPU after
#                                   use, saving some GPU memory.
#   "model_cpu_offload_and_qfloat8" as above, with float8 transformer weights.
#   "sequential_cpu_offload"        offload layer by layer; slowest, but saves
#                                   the most GPU memory.
GPU_memory_mode = "sequential_cpu_offload"

# Multi-GPU layout: ulysses_degree * ring_degree must equal the number of GPUs
# in use, e.g. (2, 4) for 8 GPUs and (1, 1) for a single GPU.
ulysses_degree, ring_degree = 1, 1

# FSDP sharding of the DiT / text encoder to save GPU memory on multiple GPUs.
fsdp_dit, fsdp_text_encoder = False, True

# torch.compile the transformer: faster at a fixed resolution, costs a little
# GPU memory. Not compatible with fsdp_dit or with sequential_cpu_offload.
compile_dit = False
# TeaCache: cache some denoising steps to speed up inference.
enable_teacache = True
# Suggested range 0.05-0.30. Larger thresholds cache more steps (faster) but the
# result may differ slightly from the uncached output.
teacache_threshold = 0.10
# Number of initial steps that always run without TeaCache, limiting its
# impact on video quality.
num_skip_start_steps = 5
# Set True to keep TeaCache tensors on the CPU (saves a little GPU memory).
teacache_offload = False

# Ratio of CFG steps to skip during inference; suggested 0.00-0.25.
cfg_skip_ratio = 0

# RIFLEx; riflex_k is the index of the intrinsic frequency.
enable_riflex, riflex_k = False, 6
# Model configuration file and model directory.
config_path = "config/wan2.2/wan_civitai_animate.yaml"
model_name = "./models/Diffusion_Transformer/Wan2.2-Animate-14B/"

# Sampler: one of "Flow", "Flow_Unipc", "Flow_DPM++".
sampler_name = "Flow_Unipc"
# [NOTE]: noise-schedule shift; affects temporal dynamics. Used with the
# "Flow_Unipc" and "Flow_DPM++" samplers.
shift = 5

# Optional checkpoint overrides (None keeps the weights under model_name):
# transformer_path -> low-noise model, transformer_high_path -> high-noise model.
transformer_path = transformer_high_path = vae_path = None
# Optional LoRA files: lora_path -> low-noise model, lora_high_path -> high-noise model.
lora_path = lora_high_path = None

# Preprocessed driving inputs: pose, face, reference image, background, mask.
src_root_path = "asset/wan_animate/replace/process_results/"
src_pose_path, src_face_path, src_ref_path, src_bg_path, src_mask_path = (
    os.path.join(src_root_path, file_name)
    for file_name in ("src_pose.mp4", "src_face.mp4", "src_ref.png", "src_bg.mp4", "src_mask.mp4")
)
# Output geometry: sample_size is [height, width] (see the pipeline call),
# followed by the number of frames and the frame rate.
sample_size = [480, 832]
video_length, fps = 81, 16
# bfloat16 by default. Some graphics cards (e.g. V100, RTX 2080 Ti) do not
# support bfloat16; use torch.float16 on those.
weight_dtype = torch.bfloat16

# Prompt ("the person in the video is performing movements") and a negative
# prompt listing artifacts to avoid (oversaturation, blur, subtitles, extra
# fingers, deformed limbs, cluttered background, walking backwards, ...).
prompt = "视频中的人在做动作"
negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"
guidance_scale = 4.0
seed = 43
num_inference_steps = 20

# LoRA strengths: lora_weight for the low-noise model, lora_high_weight for the
# high-noise model.
lora_weight = lora_high_weight = 0.55

save_path = "samples/wan-videos-animate"
# Device for this process (set_multi_gpus_devices also receives the parallel
# degrees) and the model config.
device = set_multi_gpus_devices(ulysses_degree, ring_degree)
config = OmegaConf.load(config_path)
# Forwarded to the pipeline call as `boundary` (default 0.875); presumably the
# point at which sampling switches between the high- and low-noise transformers.
boundary = config['transformer_additional_kwargs'].get('boundary', 0.875)

# Low-noise transformer; it is the only transformer for the default
# "single" combination type.
transformer = Wan2_2Transformer3DModel_Animate.from_pretrained(
    os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_low_noise_model_subpath', 'transformer')),
    transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
)
# A "moe" config adds a separate high-noise transformer. It is built with the
# same Animate class as the low-noise model; the plain Wan2_2Transformer3DModel
# name is not imported by this script, so using it here raised NameError.
if config['transformer_additional_kwargs'].get('transformer_combination_type', 'single') == "moe":
    transformer_2 = Wan2_2Transformer3DModel_Animate.from_pretrained(
        os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_high_noise_model_subpath', 'transformer')),
        transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
        low_cpu_mem_usage=True,
        torch_dtype=weight_dtype,
    )
else:
    transformer_2 = None
def _load_checkpoint_into(model, checkpoint_path):
    """Load the weights in ``checkpoint_path`` into ``model`` with ``strict=False``.

    Files ending in ``safetensors`` are read with safetensors; anything else is
    read with ``torch.load`` onto the CPU. A checkpoint wrapped as
    ``{"state_dict": ...}`` is unwrapped first. Prints the number of missing
    and unexpected keys.

    Loading inside a function lets the (potentially large) raw state dict be
    released on return instead of staying referenced by a module-level variable
    for the rest of the script.
    """
    print(f"From checkpoint: {checkpoint_path}")
    if checkpoint_path.endswith("safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(checkpoint_path)
    else:
        # NOTE(review): torch.load unpickles the file -- only point this at
        # trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
    m, u = model.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")


# Optional fine-tuned weights: transformer_path for the low-noise model,
# transformer_high_path for the high-noise model (only if it exists).
if transformer_path is not None:
    _load_checkpoint_into(transformer, transformer_path)
if transformer_2 is not None and transformer_high_path is not None:
    _load_checkpoint_into(transformer_2, transformer_high_path)
# VAE class chosen by vae_kwargs.vae_type (default "AutoencoderKLWan").
Chosen_AutoencoderKL = {
    "AutoencoderKLWan": AutoencoderKLWan,
    "AutoencoderKLWan3_8": AutoencoderKLWan3_8,
}[config['vae_kwargs'].get('vae_type', 'AutoencoderKLWan')]
vae = Chosen_AutoencoderKL.from_pretrained(
    os.path.join(model_name, config['vae_kwargs'].get('vae_subpath', 'vae')),
    additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
).to(weight_dtype)

# Optional VAE weight override (non-strict load).
if vae_path is not None:
    print(f"From checkpoint: {vae_path}")
    if vae_path.endswith("safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(vae_path)
    else:
        # NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
        state_dict = torch.load(vae_path, map_location="cpu")
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
    m, u = vae.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
    # The weights now live in `vae`; drop the raw checkpoint so it is not kept
    # alive for the rest of the script.
    del state_dict
# Tokenizer and T5 text encoder, plus the CLIP image encoder. Each is loaded
# from a sub-directory of model_name that the config can override.
tokenizer = AutoTokenizer.from_pretrained(
    os.path.join(
        model_name,
        config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer'),
    ),
)
text_encoder = WanT5EncoderModel.from_pretrained(
    os.path.join(
        model_name,
        config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder'),
    ),
    additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
).eval()
clip_image_encoder = CLIPModel.from_pretrained(
    os.path.join(
        model_name,
        config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder'),
    ),
).to(weight_dtype).eval()
# Map each supported sampler name to its scheduler class.
scheduler_dict = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
if sampler_name not in scheduler_dict:
    raise ValueError(
        f"Unknown sampler_name {sampler_name!r}; expected one of {list(scheduler_dict)}."
    )
Chosen_Scheduler = scheduler_dict[sampler_name]
if sampler_name in ("Flow_Unipc", "Flow_DPM++"):
    # The script also passes `shift` to the pipeline call; for these samplers the
    # scheduler's own shift is reset to 1 (presumably so the shift is not
    # applied twice -- confirm in Wan2_2AnimatePipeline).
    config['scheduler_kwargs']['shift'] = 1
scheduler = Chosen_Scheduler(
    **filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
)
# Assemble the animate pipeline from the loaded components. transformer_2 is
# None unless the config uses the "moe" combination type.
pipeline = Wan2_2AnimatePipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    clip_image_encoder=clip_image_encoder,
    vae=vae,
    transformer=transformer,
    transformer_2=transformer_2,
    scheduler=scheduler,
)
def _compile_transformer_blocks(model):
    """Replace every entry of ``model.blocks`` with its ``torch.compile`` wrapper."""
    blocks = model.blocks
    # Iterate over a snapshot so replacing entries cannot disturb the iteration.
    for index, block in enumerate(list(blocks)):
        blocks[index] = torch.compile(block)


# Multi-GPU inference: enable multi-GPU mode on the transformer(s) and,
# optionally, shard the DiT and/or text encoder with FSDP. The FSDP steps stay
# inside this branch: `partial` is only imported here, and sharding is only
# meant for the multi-GPU setup.
if ulysses_degree > 1 or ring_degree > 1:
    from functools import partial
    transformer.enable_multi_gpus_inference()
    if transformer_2 is not None:
        transformer_2.enable_multi_gpus_inference()
    if fsdp_dit:
        shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype)
        pipeline.transformer = shard_fn(pipeline.transformer)
        if transformer_2 is not None:
            pipeline.transformer_2 = shard_fn(pipeline.transformer_2)
        print("Add FSDP DIT")
    if fsdp_text_encoder:
        shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype)
        pipeline.text_encoder = shard_fn(pipeline.text_encoder)
        print("Add FSDP TEXT ENCODER")

# Per-block torch.compile. compile_dit is documented as incompatible with a
# FSDP-sharded DiT and with sequential CPU offload, so fail fast on those.
if compile_dit:
    dit_is_sharded = fsdp_dit and (ulysses_degree > 1 or ring_degree > 1)
    if dit_is_sharded or GPU_memory_mode == "sequential_cpu_offload":
        raise ValueError(
            "compile_dit is not compatible with fsdp_dit or "
            "GPU_memory_mode='sequential_cpu_offload'; disable one of them."
        )
    _compile_transformer_blocks(pipeline.transformer)
    if transformer_2 is not None:
        _compile_transformer_blocks(pipeline.transformer_2)
    print("Add Compile")
def _quantize_transformers_to_float8():
    """Convert ``transformer`` (and ``transformer_2`` if present) to float8 weights.

    ``modulation`` modules are excluded from the conversion; each converted
    model is then wrapped by ``convert_weight_dtype_wrapper`` with ``weight_dtype``.
    """
    for model in (transformer, transformer_2):
        if model is None:
            continue
        convert_model_weight_to_float8(model, exclude_module_name=["modulation",], device=device)
        convert_weight_dtype_wrapper(model, weight_dtype)


# Place the models according to GPU_memory_mode. An unrecognised value (e.g. a
# typo) is rejected instead of silently falling back to a full GPU load.
if GPU_memory_mode == "sequential_cpu_offload":
    # NOTE(review): presumably keeps the `modulation` parameters and the `freqs`
    # tensor on `device` while everything else is offloaded -- confirm against
    # replace_parameters_by_name.
    for model in (transformer, transformer_2):
        if model is None:
            continue
        replace_parameters_by_name(model, ["modulation",], device=device)
        model.freqs = model.freqs.to(device=device)
    pipeline.enable_sequential_cpu_offload(device=device)
elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
    _quantize_transformers_to_float8()
    pipeline.enable_model_cpu_offload(device=device)
elif GPU_memory_mode == "model_cpu_offload":
    pipeline.enable_model_cpu_offload(device=device)
elif GPU_memory_mode == "model_full_load_and_qfloat8":
    _quantize_transformers_to_float8()
    pipeline.to(device=device)
elif GPU_memory_mode == "model_full_load":
    pipeline.to(device=device)
else:
    raise ValueError(
        f"Unknown GPU_memory_mode {GPU_memory_mode!r}; expected one of "
        "'model_full_load', 'model_full_load_and_qfloat8', 'model_cpu_offload', "
        "'model_cpu_offload_and_qfloat8', 'sequential_cpu_offload'."
    )
# TeaCache is switched on only when coefficients are available for this
# model_name (get_teacache_coefficients may return None). A high-noise
# transformer, if present, shares the low-noise transformer's TeaCache setup.
coefficients = None
if enable_teacache:
    coefficients = get_teacache_coefficients(model_name)
if coefficients is not None:
    print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.")
    pipeline.transformer.enable_teacache(
        coefficients,
        num_inference_steps,
        teacache_threshold,
        num_skip_start_steps=num_skip_start_steps,
        offload=teacache_offload,
    )
    if transformer_2 is not None:
        pipeline.transformer_2.share_teacache(transformer=pipeline.transformer)

# CFG skipping: any non-None ratio (including 0) is forwarded.
if cfg_skip_ratio is not None:
    print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.")
    pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, num_inference_steps)
    if transformer_2 is not None:
        pipeline.transformer_2.share_cfg_skip(transformer=pipeline.transformer)
# Seeded generator on the target device so sampling is reproducible.
generator = torch.Generator(device=device).manual_seed(seed)

# Merge LoRA weights before sampling: lora_path/lora_weight into the low-noise
# transformer, lora_high_path/lora_high_weight into transformer_2. Each LoRA is
# optional on its own, so an unset lora_high_path is skipped rather than being
# handed to merge_lora as None.
if lora_path is not None:
    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
if transformer_2 is not None and lora_high_path is not None:
    pipeline = merge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")
with torch.no_grad():
    # Round video_length down to 1 + a multiple of the VAE temporal compression
    # ratio (a single frame stays 1), then derive the number of latent frames.
    temporal_ratio = vae.config.temporal_compression_ratio
    if video_length != 1:
        video_length = int((video_length - 1) // temporal_ratio * temporal_ratio) + 1
    latent_frames = (video_length - 1) // temporal_ratio + 1

    if enable_riflex:
        pipeline.transformer.enable_riflex(k=riflex_k, L_test=latent_frames)
        if transformer_2 is not None:
            pipeline.transformer_2.enable_riflex(k=riflex_k, L_test=latent_frames)

    # Driving inputs. Pose (and, in replace mode, background and mask) are
    # loaded at the output sample_size; face crops are loaded at 512x512.
    video_kwargs = dict(video_length=video_length, fps=fps, ref_image=None)
    pose_video, _, _, _ = get_video_to_video_latent(src_pose_path, sample_size=sample_size, **video_kwargs)
    face_video, _, _, _ = get_video_to_video_latent(src_face_path, sample_size=[512, 512], **video_kwargs)
    ref_image = get_image(src_ref_path)

    # Replace mode is selected by the presence of a background video; the mask
    # video is then loaded too and sliced to its first entry along dim 1
    # (presumably the channel axis -- confirm).
    replace_flag = os.path.exists(src_bg_path)
    if replace_flag:
        bg_video, _, _, _ = get_video_to_video_latent(src_bg_path, sample_size=sample_size, **video_kwargs)
        mask_video, _, _, _ = get_video_to_video_latent(src_mask_path, sample_size=sample_size, **video_kwargs)
        mask_video = mask_video[:, :1]
    else:
        bg_video = mask_video = None

    sample = pipeline(
        prompt,
        negative_prompt=negative_prompt,
        height=sample_size[0],
        width=sample_size[1],
        num_frames=video_length,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
        boundary=boundary,
        shift=shift,
        ref_image=ref_image,
        pose_video=pose_video,
        face_video=face_video,
        bg_video=bg_video,
        mask_video=mask_video,
        replace_flag=replace_flag,
    ).videos
# Undo the LoRA merges after sampling so the pipeline holds the base weights
# again. The high-noise LoRA is only unmerged when lora_high_path is set, so an
# unset path is never handed to unmerge_lora as None.
if lora_path is not None:
    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
if transformer_2 is not None and lora_high_path is not None:
    pipeline = unmerge_lora(pipeline, lora_high_path, lora_high_weight, device=device, dtype=weight_dtype, sub_transformer_name="transformer_2")
def save_results():
    """Write the generated ``sample`` to a new, numbered file in ``save_path``.

    Reads the module-level ``sample``, ``save_path``, ``video_length`` and
    ``fps``. A single-frame result (``video_length == 1``) is written as a PNG;
    anything longer is written as an MP4 via ``save_videos_grid`` at ``fps``.
    Existing files are never overwritten.
    """
    os.makedirs(save_path, exist_ok=True)

    # Start numbering after the current number of entries, but step past any
    # prefix already in use: if an earlier output was deleted, the count alone
    # would point at an existing file and overwrite it.
    index = len(os.listdir(save_path)) + 1
    while any(
        os.path.exists(os.path.join(save_path, str(index).zfill(8) + ext))
        for ext in (".png", ".mp4")
    ):
        index += 1
    prefix = str(index).zfill(8)

    if video_length == 1:
        video_path = os.path.join(save_path, prefix + ".png")
        # NOTE(review): assumes sample is (batch, channels, frames, height,
        # width) on the CPU with values in [0, 1]; take the first frame of the
        # first item and move channels last for PIL.
        image = sample[0, :, 0]
        image = image.transpose(0, 1).transpose(1, 2)
        image = (image * 255).numpy().astype(np.uint8)
        Image.fromarray(image).save(video_path)
    else:
        video_path = os.path.join(save_path, prefix + ".mp4")
        save_videos_grid(sample, video_path, fps=fps)
# Single process: save directly. Multi-GPU: only rank 0 writes the output.
if ulysses_degree * ring_degree <= 1:
    save_results()
else:
    import torch.distributed as dist
    if dist.get_rank() == 0:
        save_results()