"""Subject-to-video inference script built around ``WanFunPhantomPipeline``.

The script loads the Wan transformer, VAE, tokenizer and T5 text encoder from
``model_name``, optionally overlays extra checkpoints and a LoRA, runs the
pipeline with the configured prompt and subject reference images, and writes
the result (``.png`` for a single frame, ``.mp4`` otherwise) into ``save_path``.
All settings are plain module-level variables in the configuration section.
"""
import os
import sys

import numpy as np
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from omegaconf import OmegaConf
from PIL import Image
from transformers import AutoTokenizer

# Make the project importable when this file is run directly: this file's
# directory and its two parent directories are put on sys.path.
current_file_path = os.path.abspath(__file__)
project_roots = [
    os.path.dirname(current_file_path),
    os.path.dirname(os.path.dirname(current_file_path)),
    os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))),
]
for project_root in project_roots:
    if project_root not in sys.path:
        sys.path.insert(0, project_root)

from videox_fun.dist import set_multi_gpus_devices, shard_model
# NOTE: this import rebinds AutoTokenizer, replacing the name imported from
# transformers above.
from videox_fun.models import (AutoencoderKLWan, AutoTokenizer, WanT5EncoderModel,
                               WanTransformer3DModel)
from videox_fun.models.cache_utils import get_teacache_coefficients
from videox_fun.pipeline import WanFunPhantomPipeline
from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8,
                                               convert_weight_dtype_wrapper,
                                               replace_parameters_by_name)
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from videox_fun.utils.utils import (filter_kwargs, get_image_latent,
                                    save_videos_grid)
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# GPU memory mode, which can be chosen in [model_full_load, model_full_load_and_qfloat8,
# model_cpu_offload, model_cpu_offload_and_qfloat8, sequential_cpu_offload].
# model_full_load means that the entire model will be moved to the GPU.
#
# model_full_load_and_qfloat8 means that the entire model will be moved to the GPU,
# and the transformer model has been quantized to float8, which can save more GPU memory.
#
# model_cpu_offload means that the entire model will be moved to the CPU after use,
# which can save some GPU memory.
#
# model_cpu_offload_and_qfloat8 indicates that the entire model will be moved to the CPU after use,
# and the transformer model has been quantized to float8, which can save more GPU memory.
#
# sequential_cpu_offload means that each layer of the model will be moved to the CPU after use,
# resulting in slower speeds but saving a large amount of GPU memory.
GPU_memory_mode = "sequential_cpu_offload"

# Multi GPUs config
# Please ensure that the product of ulysses_degree and ring_degree equals the number of GPUs used.
# For example, if you are using 8 GPUs, you can set ulysses_degree = 2 and ring_degree = 4.
# If you are using 1 GPU, you can set ulysses_degree = 1 and ring_degree = 1.
ulysses_degree = 1
ring_degree = 1
# Use FSDP to save more GPU memory in multi gpus.
fsdp_dit = False
fsdp_text_encoder = True
# Compile will give a speedup in fixed resolution and need a little GPU memory.
# The compile_dit is not compatible with the fsdp_dit and sequential_cpu_offload.
compile_dit = False

# Support TeaCache.
enable_teacache = True
# Recommended to be set between 0.05 and 0.30. A larger threshold can cache more steps,
# speeding up the inference process, but it may cause slight differences between the
# generated content and the original content.
# # --------------------------------------------------------------------------------------------------- #
# | Model Name          | threshold | Model Name          | threshold | Model Name          | threshold |
# | Wan2.1-T2V-1.3B     | 0.05~0.10 | Wan2.1-T2V-14B      | 0.10~0.15 | Wan2.1-I2V-14B-720P | 0.20~0.30 |
# | Wan2.1-I2V-14B-480P | 0.20~0.25 | Wan2.1-Fun-*-1.3B-* | 0.05~0.10 | Wan2.1-Fun-*-14B-*  | 0.20~0.30 |
# # --------------------------------------------------------------------------------------------------- #
teacache_threshold = 0.10
# The number of steps to skip TeaCache at the beginning of the inference process, which can
# reduce the impact of TeaCache on generated video quality.
num_skip_start_steps = 5
# Whether to offload TeaCache tensors to cpu to save a little bit of GPU memory.
teacache_offload = False

# Skip some cfg steps in inference
# Recommended to be set between 0.00 and 0.25
cfg_skip_ratio = 0

# Riflex config
enable_riflex = False
# Index of intrinsic frequency
riflex_k = 6

# Config and model path
config_path = "config/wan2.1/wan_civitai.yaml"
# model path
model_name = "models/Diffusion_Transformer/Wan2.1-T2V-1.3B"

# Choose the sampler in "Flow", "Flow_Unipc", "Flow_DPM++"
sampler_name = "Flow_Unipc"
# [NOTE]: Noise schedule shift parameter. Affects temporal dynamics.
# Used when the sampler is in "Flow_Unipc", "Flow_DPM++".
# If you want to generate a 480p video, it is recommended to set the shift value to 3.0.
# If you want to generate a 720p video, it is recommended to set the shift value to 5.0.
shift = 3

# Load pretrained model if need
transformer_path = "models/Personalized_Model/Phantom-Wan-1.3B.safetensors"
vae_path = None
lora_path = None

# Other params
sample_size = [480, 832]
video_length = 81
fps = 16

# Use torch.float16 if GPU does not support torch.bfloat16
# Some graphics cards, such as v100, 2080ti, do not support torch.bfloat16
weight_dtype = torch.bfloat16
subject_ref_images = ["asset/ref_1.png", "asset/ref_2.png"]

# Prompt tips (apply to both the Chinese prompt below and the English example):
# - A longer negative prompt, e.g. "blurry, mutation, deformation, distortion, dark frame,
#   text subtitles, fixed frame, comic strip, manga, line art, no subject", can increase stability.
# - Adding words such as "quiet, solid" to the negative prompt can increase dynamism.
prompt = "暖阳漫过草地,扎着双马尾、头戴绿色蝴蝶结、身穿浅绿色连衣裙的小女孩蹲在盛开的雏菊旁。她身旁一只棕白相间的狗狗吐着舌头,毛茸茸尾巴欢快摇晃。小女孩笑着举起黄红配色、带有蓝色按钮的玩具相机,将和狗狗的欢乐瞬间定格。"
negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走"

# English example (alternative to the prompt pair above):
# prompt = "A young woman with beautiful, clear eyes and blonde hair stands in the forest, wearing a white dress and a crown. Her expression is serene, reminiscent of a movie star, with fair and youthful skin. Her brown long hair flows in the wind. The video quality is very high, with a clear view. High quality, masterpiece, best quality, high resolution, ultra-fine, fantastical."
# negative_prompt = "Twisted body, limb deformities, text captions, comic, static, ugly, error, messy code."
guidance_scale = 6.0
seed = 43
num_inference_steps = 50
lora_weight = 0.55
save_path = "samples/wan-videos-phantom"


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _load_checkpoint_into(model, checkpoint_path):
    """Load the weights stored at ``checkpoint_path`` into ``model``.

    Paths ending in "safetensors" are read with ``safetensors.torch.load_file``;
    every other path is read with ``torch.load`` onto the CPU. If the loaded
    object has a top-level "state_dict" entry, that entry is used. Loading is
    non-strict, and the number of missing / unexpected keys is printed.

    Args:
        model: Object exposing ``load_state_dict(state_dict, strict=False)``.
        checkpoint_path: Path of the checkpoint file.
    """
    print(f"From checkpoint: {checkpoint_path}")
    if checkpoint_path.endswith("safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(checkpoint_path)
    else:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code. Only point this at checkpoints you trust.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
    if "state_dict" in state_dict:
        state_dict = state_dict["state_dict"]

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")


# ---------------------------------------------------------------------------
# Model construction
# ---------------------------------------------------------------------------

device = set_multi_gpus_devices(ulysses_degree, ring_degree)
config = OmegaConf.load(config_path)

# Get Transformer
transformer = WanTransformer3DModel.from_pretrained(
    os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
    transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
)

if transformer_path is not None:
    _load_checkpoint_into(transformer, transformer_path)

# Get Vae
vae = AutoencoderKLWan.from_pretrained(
    os.path.join(model_name, config['vae_kwargs'].get('vae_subpath', 'vae')),
    additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
).to(weight_dtype)

if vae_path is not None:
    _load_checkpoint_into(vae, vae_path)

# Get Tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    os.path.join(model_name, config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer')),
)

# Get Text encoder
text_encoder = WanT5EncoderModel.from_pretrained(
    os.path.join(model_name, config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder')),
    additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
)
text_encoder = text_encoder.eval()

# Get Scheduler
scheduler_dict = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
if sampler_name not in scheduler_dict:
    raise ValueError(
        f"Unknown sampler_name {sampler_name!r}; expected one of {list(scheduler_dict)}."
    )
Chosen_Scheduler = scheduler_dict[sampler_name]
if sampler_name in ("Flow_Unipc", "Flow_DPM++"):
    # The scheduler itself is built with shift 1; the configured `shift` is
    # handed to the pipeline call further down.
    # NOTE(review): presumably the pipeline applies it per run - confirm.
    config['scheduler_kwargs']['shift'] = 1
scheduler = Chosen_Scheduler(
    **filter_kwargs(Chosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
)

# Get Pipeline
pipeline = WanFunPhantomPipeline(
    transformer=transformer,
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    scheduler=scheduler,
)

# ---------------------------------------------------------------------------
# Device placement and optional accelerations
# ---------------------------------------------------------------------------

if ulysses_degree > 1 or ring_degree > 1:
    from functools import partial
    transformer.enable_multi_gpus_inference()
    if fsdp_dit:
        shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype)
        pipeline.transformer = shard_fn(pipeline.transformer)
        print("Add FSDP DIT")
    if fsdp_text_encoder:
        shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype)
        pipeline.text_encoder = shard_fn(pipeline.text_encoder)
        print("Add FSDP TEXT ENCODER")

if compile_dit:
    for i in range(len(pipeline.transformer.blocks)):
        pipeline.transformer.blocks[i] = torch.compile(pipeline.transformer.blocks[i])
    print("Add Compile")

if GPU_memory_mode == "sequential_cpu_offload":
    replace_parameters_by_name(transformer, ["modulation",], device=device)
    transformer.freqs = transformer.freqs.to(device=device)
    pipeline.enable_sequential_cpu_offload(device=device)
elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
    convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device)
    convert_weight_dtype_wrapper(transformer, weight_dtype)
    pipeline.enable_model_cpu_offload(device=device)
elif GPU_memory_mode == "model_cpu_offload":
    pipeline.enable_model_cpu_offload(device=device)
elif GPU_memory_mode == "model_full_load_and_qfloat8":
    convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",], device=device)
    convert_weight_dtype_wrapper(transformer, weight_dtype)
    pipeline.to(device=device)
else:
    pipeline.to(device=device)

coefficients = get_teacache_coefficients(model_name) if enable_teacache else None
if coefficients is not None:
    print(f"Enable TeaCache with threshold {teacache_threshold} and skip the first {num_skip_start_steps} steps.")
    pipeline.transformer.enable_teacache(
        coefficients, num_inference_steps, teacache_threshold,
        num_skip_start_steps=num_skip_start_steps, offload=teacache_offload
    )

if cfg_skip_ratio is not None:
    print(f"Enable cfg_skip_ratio {cfg_skip_ratio}.")
    pipeline.transformer.enable_cfg_skip(cfg_skip_ratio, num_inference_steps)

generator = torch.Generator(device=device).manual_seed(seed)

if lora_path is not None:
    pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)

# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------

with torch.no_grad():
    # Round (video_length - 1) down to a multiple of the VAE temporal
    # compression ratio, then add the first frame back. A single frame
    # (video_length == 1) is left untouched.
    temporal_ratio = vae.config.temporal_compression_ratio
    if video_length != 1:
        video_length = int((video_length - 1) // temporal_ratio * temporal_ratio) + 1
    else:
        video_length = 1
    latent_frames = (video_length - 1) // temporal_ratio + 1

    if enable_riflex:
        pipeline.transformer.enable_riflex(k = riflex_k, L_test = latent_frames)

    if subject_ref_images is not None:
        # Convert every reference image path with get_image_latent, then
        # concatenate the results along dim 2.
        subject_ref_images = [
            get_image_latent(_subject_ref_image, sample_size=sample_size, padding=True)
            for _subject_ref_image in subject_ref_images
        ]
        subject_ref_images = torch.cat(subject_ref_images, dim=2)

    sample = pipeline(
        prompt,
        num_frames = video_length,
        negative_prompt = negative_prompt,
        height = sample_size[0],
        width = sample_size[1],
        generator = generator,
        guidance_scale = guidance_scale,
        num_inference_steps = num_inference_steps,
        subject_ref_images = subject_ref_images,
        shift = shift,
    ).videos

if lora_path is not None:
    pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)


# ---------------------------------------------------------------------------
# Output
# ---------------------------------------------------------------------------

def save_results():
    """Write the module-level ``sample`` into ``save_path``.

    The output file name is the number of entries already in ``save_path``
    plus one, zero-padded to 8 digits. A single-frame result
    (``video_length == 1``) is saved as a PNG; anything else is saved as an
    MP4 via ``save_videos_grid`` at ``fps`` frames per second.
    """
    # exist_ok=True already covers the "directory exists" case, so no prior
    # os.path.exists check is needed.
    os.makedirs(save_path, exist_ok=True)

    index = len(os.listdir(save_path)) + 1
    prefix = str(index).zfill(8)
    if video_length == 1:
        video_path = os.path.join(save_path, prefix + ".png")

        # First batch item, first frame; then move the leading axis last.
        # NOTE(review): assumes that axis is channels and values lie in [0, 1] - confirm.
        image = sample[0, :, 0]
        image = image.transpose(0, 1).transpose(1, 2)
        image = (image * 255).numpy().astype(np.uint8)
        image = Image.fromarray(image)
        image.save(video_path)
    else:
        video_path = os.path.join(save_path, prefix + ".mp4")
        save_videos_grid(sample, video_path, fps=fps)


# With more than one GPU, only rank 0 writes the result.
if ulysses_degree * ring_degree > 1:
    import torch.distributed as dist
    if dist.get_rank() == 0:
        save_results()
else:
    save_results()