|
|
import os
|
|
|
import sys
|
|
|
|
|
|
import numpy as np
|
|
|
import torch
|
|
|
from diffusers import FlowMatchEulerDiscreteScheduler
|
|
|
from omegaconf import OmegaConf
|
|
|
from PIL import Image
|
|
|
from transformers import AutoTokenizer
|
|
|
|
|
|
# Make the repository's local packages (such as `cogvideox`, imported below) importable
# when this script is run directly.  This file's directory and its two ancestors are
# prepended to sys.path unless already present.  Each insert goes to the front, so the
# outermost ancestor ends up first.
current_file_path = os.path.abspath(__file__)
_script_dir = os.path.dirname(current_file_path)
project_roots = [
    _script_dir,
    os.path.dirname(_script_dir),
    os.path.dirname(os.path.dirname(_script_dir)),
]

for project_root in project_roots:
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
|
|
|
|
|
|
from cogvideox.models import (AutoencoderKLWan, CLIPModel, WanT5EncoderModel,
|
|
|
WanTransformer3DModel)
|
|
|
from cogvideox.pipeline import WanFunInpaintPipeline
|
|
|
from cogvideox.utils.fp8_optimization import (convert_model_weight_to_float8, replace_parameters_by_name,
|
|
|
convert_weight_dtype_wrapper)
|
|
|
from cogvideox.utils.lora_utils import merge_lora, unmerge_lora
|
|
|
from cogvideox.utils.utils import (filter_kwargs, get_image_to_video_latent,
|
|
|
save_videos_grid)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
GPU_memory_mode = "sequential_cpu_offload"
|
|
|
|
|
|
|
|
|
config_path = "config/wan2.1/wan_civitai.yaml"
|
|
|
|
|
|
model_name = "models/Diffusion_Transformer/Wan2.1-Fun-1.3B-InP"
|
|
|
|
|
|
|
|
|
sampler_name = "Flow"
|
|
|
|
|
|
|
|
|
transformer_path = None
|
|
|
vae_path = None
|
|
|
lora_path = None
|
|
|
|
|
|
|
|
|
sample_size = [480, 832]
|
|
|
video_length = 81
|
|
|
fps = 16
|
|
|
|
|
|
|
|
|
|
|
|
weight_dtype = torch.bfloat16
|
|
|
|
|
|
validation_image_start = "asset/1.png"
|
|
|
validation_image_end = None
|
|
|
|
|
|
|
|
|
prompt = "ไธๅชๆฃ่ค่ฒ็็ๆญฃๆๆ็่่ข๏ผๅๅจไธไธช่้็ๆฟ้ด้็ๆต
่ฒๆฒๅไธใๆฒๅ็่ตทๆฅๆ่ฝฏ่ๅฎฝๆ๏ผไธบ่ฟๅชๆดปๆณผ็็็ๆไพไบไธไธชๅฎ็พ็ไผๆฏๅฐ็นใๅจ็็ๅ้ข๏ผ้ ๅขๆๆพ็ไธไธชๆถๅญ๏ผๆถๅญไธๆ็ไธๅน
็ฒพ็พ็้ถๆก็ป๏ผ็ปไธญๆ็ป็ไธไบ็พไธฝ็้ฃๆฏๆๅบๆฏใ็ปๆกๅจๅด่ฃ
้ฅฐ็็ฒ็บข่ฒ็่ฑๆต๏ผ่ฟไบ่ฑๆตไธไป
ๅขๆทปไบๆฟ้ด็่ฒๅฝฉ๏ผ่ฟๅธฆๆฅไบไธไธ่ช็ถๅ็ๆบใๆฟ้ด้็็ฏๅ
ๆๅ่ๆธฉๆ๏ผไปๅคฉ่ฑๆฟไธ็ๅ็ฏๅ่ง่ฝ้็ๅฐ็ฏๆฃๅๅบๆฅ๏ผ่ฅ้ ๅบไธ็งๆธฉ้ฆจ่้็ๆฐๅดใๆดไธช็ฉบ้ด็ปไบบไธ็งๅฎ้ๅ่ฐ็ๆ่ง๏ผไปฟไฝๆถ้ดๅจ่ฟ้ๅๅพ็ผๆ
ข่็พๅฅฝใ"
|
|
|
negative_prompt = "่ฒ่ฐ่ณไธฝ๏ผ่ฟๆ๏ผ้ๆ๏ผ็ป่ๆจก็ณไธๆธ
๏ผๅญๅน๏ผ้ฃๆ ผ๏ผไฝๅ๏ผ็ปไฝ๏ผ็ป้ข๏ผ้ๆญข๏ผๆดไฝๅ็ฐ๏ผๆๅทฎ่ดจ้๏ผไฝ่ดจ้๏ผJPEGๅ็ผฉๆฎ็๏ผไธ้็๏ผๆฎ็ผบ็๏ผๅคไฝ็ๆๆ๏ผ็ปๅพไธๅฅฝ็ๆ้จ๏ผ็ปๅพไธๅฅฝ็่ธ้จ๏ผ็ธๅฝข็๏ผๆฏๅฎน็๏ผๅฝขๆ็ธๅฝข็่ขไฝ๏ผๆๆ่ๅ๏ผ้ๆญขไธๅจ็็ป้ข๏ผๆไนฑ็่ๆฏ๏ผไธๆก่
ฟ๏ผ่ๆฏไบบๅพๅค๏ผๅ็่ตฐ"
|
|
|
guidance_scale = 6.0
|
|
|
seed = 43
|
|
|
num_inference_steps = 50
|
|
|
lora_weight = 0.55
|
|
|
save_path = "samples/wan-videos-fun-i2v"
|
|
|
|
|
|
# Model config; its *_kwargs sections provide constructor arguments and optional
# sub-folder overrides for each component under model_name.
config = OmegaConf.load(config_path)


def _load_checkpoint_into(model, checkpoint_path):
    """Load weights from *checkpoint_path* into *model* non-strictly and report mismatches.

    Paths ending in ``safetensors`` are read with ``safetensors.torch.load_file``;
    anything else is read with ``torch.load`` onto the CPU.  If the loaded mapping has a
    top-level ``"state_dict"`` entry, that nested mapping is used instead.  The counts of
    missing and unexpected keys reported by ``load_state_dict(strict=False)`` are printed.
    """
    print(f"From checkpoint: {checkpoint_path}")
    if checkpoint_path.endswith("safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(checkpoint_path)
    else:
        # NOTE(review): torch.load unpickles arbitrary objects -- only use trusted checkpoints.
        state_dict = torch.load(checkpoint_path, map_location="cpu")
    # Some training checkpoints wrap the weights in a {"state_dict": ...} container.
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(missing)}, unexpected keys: {len(unexpected)}")


# Diffusion transformer, created directly in weight_dtype.
transformer = WanTransformer3DModel.from_pretrained(
    os.path.join(model_name, config['transformer_additional_kwargs'].get('transformer_subpath', 'transformer')),
    transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
)
if transformer_path is not None:
    _load_checkpoint_into(transformer, transformer_path)

# VAE, cast to weight_dtype after loading.
vae = AutoencoderKLWan.from_pretrained(
    os.path.join(model_name, config['vae_kwargs'].get('vae_subpath', 'vae')),
    additional_kwargs=OmegaConf.to_container(config['vae_kwargs']),
).to(weight_dtype)
if vae_path is not None:
    _load_checkpoint_into(vae, vae_path)
|
|
|
|
|
|
|
|
|
# Text-side components share the `text_encoder_kwargs` config section, which may
# override the sub-folder names under model_name.  The CLIP image encoder uses
# `image_encoder_kwargs` in the same way.  All encoders are cast to weight_dtype and put
# in eval mode (the chained .to()/.eval() calls return the module itself).
tokenizer = AutoTokenizer.from_pretrained(
    os.path.join(
        model_name,
        config['text_encoder_kwargs'].get('tokenizer_subpath', 'tokenizer'),
    )
)

text_encoder = (
    WanT5EncoderModel.from_pretrained(
        os.path.join(
            model_name,
            config['text_encoder_kwargs'].get('text_encoder_subpath', 'text_encoder'),
        ),
        additional_kwargs=OmegaConf.to_container(config['text_encoder_kwargs']),
    )
    .to(weight_dtype)
    .eval()
)

clip_image_encoder = (
    CLIPModel.from_pretrained(
        os.path.join(
            model_name,
            config['image_encoder_kwargs'].get('image_encoder_subpath', 'image_encoder'),
        )
    )
    .to(weight_dtype)
    .eval()
)
|
|
|
|
|
|
|
|
|
# Sampler name -> scheduler class.  Only flow matching is wired up in this script.
scheduler_dict = {
    "Flow": FlowMatchEulerDiscreteScheduler,
}
if sampler_name not in scheduler_dict:
    raise ValueError(
        f"Unsupported sampler_name {sampler_name!r}; expected one of {sorted(scheduler_dict)}"
    )
Choosen_Scheduler = scheduler_dict[sampler_name]

# Build the scheduler from the config's scheduler_kwargs.  filter_kwargs presumably
# drops keys the scheduler constructor does not accept -- confirm in cogvideox.utils.
scheduler = Choosen_Scheduler(
    **filter_kwargs(Choosen_Scheduler, OmegaConf.to_container(config['scheduler_kwargs']))
)
|
|
|
|
|
|
|
|
|
# Assemble the image-to-video inpainting pipeline from the components loaded above.
pipeline = WanFunInpaintPipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    clip_image_encoder=clip_image_encoder,
    transformer=transformer,
    vae=vae,
    scheduler=scheduler,
)

# Apply the configured GPU memory strategy.
if GPU_memory_mode == "sequential_cpu_offload":
    # The "modulation" parameters and the transformer's `freqs` are placed on CUDA
    # before enabling sequential offload (presumably so they are not offloaded -- confirm).
    replace_parameters_by_name(transformer, ["modulation",], device="cuda")
    transformer.freqs = transformer.freqs.to(device="cuda")
    pipeline.enable_sequential_cpu_offload()
else:
    if GPU_memory_mode == "model_cpu_offload_and_qfloat8":
        # Convert transformer weights to float8 (skipping "modulation") and wrap them so
        # computation uses weight_dtype, per the helper names -- confirm in fp8_optimization.
        convert_model_weight_to_float8(transformer, exclude_module_name=["modulation",])
        convert_weight_dtype_wrapper(transformer, weight_dtype)
    pipeline.enable_model_cpu_offload()
|
|
|
|
|
|
# CUDA generator seeded from `seed` so sampling is reproducible.
generator = torch.Generator(device="cuda").manual_seed(seed)

# Optionally merge LoRA weights into the pipeline for this run.
if lora_path is not None:
    pipeline = merge_lora(pipeline, lora_path, lora_weight)

try:
    with torch.no_grad():
        # Snap the requested frame count down to k * temporal_compression_ratio + 1;
        # a single-frame request stays at 1.
        temporal_ratio = vae.config.temporal_compression_ratio
        if video_length != 1:
            video_length = int((video_length - 1) // temporal_ratio * temporal_ratio) + 1
        # Number of latent frames after temporal compression (not used further below).
        latent_frames = (video_length - 1) // temporal_ratio + 1

        # Conditioning video, its mask, and the CLIP image derived from the start/end images.
        input_video, input_video_mask, clip_image = get_image_to_video_latent(
            validation_image_start,
            validation_image_end,
            video_length=video_length,
            sample_size=sample_size,
        )

        sample = pipeline(
            prompt,
            num_frames=video_length,
            negative_prompt=negative_prompt,
            height=sample_size[0],
            width=sample_size[1],
            generator=generator,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            video=input_video,
            mask_video=input_video_mask,
            clip_image=clip_image,
        ).videos
finally:
    # Undo the LoRA merge even if generation raises, so the pipeline weights are restored.
    if lora_path is not None:
        pipeline = unmerge_lora(pipeline, lora_path, lora_weight)

os.makedirs(save_path, exist_ok=True)

# Outputs are named with a zero-padded running number.  Start at (entries in save_path)
# + 1, then step past numbers that are already taken.  The bare count collides with an
# existing file once earlier outputs have been deleted, and saving would overwrite it.
suffix = ".png" if video_length == 1 else ".mp4"
index = len(os.listdir(save_path)) + 1
while os.path.exists(os.path.join(save_path, str(index).zfill(8) + suffix)):
    index += 1
prefix = str(index).zfill(8)
video_path = os.path.join(save_path, prefix + suffix)

if video_length == 1:
    # Single frame: save as PNG.  sample[0, :, 0] takes batch 0, all channels, frame 0;
    # assumed to be a (C, H, W) float tensor with values in [0, 1] -- TODO confirm.
    image = sample[0, :, 0]
    image = image.permute(1, 2, 0)  # (C, H, W) -> (H, W, C) as PIL expects
    # Clamp so slightly out-of-range values do not wrap around in uint8.
    # .float() lets bfloat16 tensors convert, since NumPy has no bfloat16.
    image = (image.float().clamp(0, 1) * 255).cpu().numpy().astype(np.uint8)
    image = Image.fromarray(image)
    image.save(video_path)
else:
    save_videos_grid(sample, video_path, fps=fps)