yongqiang
initialize this repo
ba96580
import os
import sys
import numpy as np
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from omegaconf import OmegaConf
from PIL import Image
# Put this file's directory and up to two of its parent directories on
# sys.path so the project-local imports that follow (e.g. ``videox_fun``)
# resolve whether the script is run from the repo root or a nested folder.
current_file_path = os.path.abspath(__file__)
project_roots = [
    os.path.dirname(current_file_path),
    os.path.dirname(os.path.dirname(current_file_path)),
    os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))),
]
for project_root in project_roots:
    # Plain statement instead of a conditional expression used only for its side effect.
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
from videox_fun.dist import set_multi_gpus_devices, shard_model
from videox_fun.models import (AutoencoderKL, AutoTokenizer,
Qwen3ForCausalLM, ZImageControlTransformer2DModel)
from videox_fun.models.cache_utils import get_teacache_coefficients
from videox_fun.pipeline import ZImageControlPipeline
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8,
convert_weight_dtype_wrapper)
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from videox_fun.utils.utils import (filter_kwargs, get_image_to_video_latent, get_image_latent, get_image,
get_video_to_video_latent,
save_videos_grid)
from loguru import logger
import onnx
import subprocess
def run_onnxslim(input_file="vae.onnx", output_file="vae_slim.onnx", executable="onnxslim"):
    """Simplify an ONNX model by running the ``onnxslim`` command-line tool.

    The tool's stdout is echoed line by line while it runs; its stderr is
    collected and printed only if the command fails.

    Args:
        input_file: Path of the ONNX model to simplify.
        output_file: Path where the simplified model is written.
        executable: Name or full path of the ``onnxslim`` executable
            (default ``"onnxslim"``, resolved through ``PATH``).

    Returns:
        True if the command exited with status 0; False if it failed, the
        executable could not be found, or any other error occurred.
    """
    import threading

    cmd = [executable, input_file, output_file]
    print(f"执行命令: {' '.join(cmd)}")
    try:
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,
        ) as process:
            # Drain stderr on a background thread. Reading only stdout while
            # stderr is also a pipe can deadlock: once the child fills the
            # stderr pipe buffer it blocks and stdout never reaches EOF.
            stderr_chunks = []
            stderr_reader = threading.Thread(
                target=lambda: stderr_chunks.append(process.stderr.read()),
                daemon=True,
            )
            stderr_reader.start()

            # Echo the tool's stdout in real time.
            for line in process.stdout:
                print(line, end='')

            process.wait()
            stderr_reader.join()

        stderr = "".join(stderr_chunks)
        if process.returncode != 0:
            print(f"命令执行失败, 错误信息:\n{stderr}")
            return False
        print("ONNX模型压缩完成!")
        return True
    except FileNotFoundError:
        print("错误: 未找到 onnxslim 命令, 请确保已安装 onnxslim")
        # The onnxslim CLI ships in the `onnxslim` package (onnx-simplifier provides `onnxsim`).
        print("安装方法: pip install onnxslim")
        return False
    except Exception as e:
        # Best-effort helper: report and signal failure instead of raising.
        print(f"执行命令时发生错误: {e}")
        return False
# ---------------------------------------------------------------------------
# GPU memory strategy. Supported values:
#   "model_full_load"               - keep the whole pipeline on the GPU.
#   "model_full_load_and_qfloat8"   - as above, with the transformer weights
#                                     quantized to float8 to save GPU memory.
#   "model_cpu_offload"             - move each model back to the CPU after it
#                                     has been used; saves some GPU memory.
#   "model_cpu_offload_and_qfloat8" - CPU offload plus float8 transformer
#                                     weights; saves more GPU memory.
#   "sequential_cpu_offload"        - offload layer by layer; slowest, but
#                                     uses the least GPU memory.
# ---------------------------------------------------------------------------
GPU_memory_mode = "model_cpu_offload"

# Multi-GPU sequence parallelism: ulysses_degree * ring_degree must equal the
# number of GPUs in use (e.g. 8 GPUs -> 2 * 4; a single GPU -> 1 * 1).
ulysses_degree, ring_degree = 1, 1

# FSDP sharding, to reduce per-GPU memory when running on several GPUs.
fsdp_dit = fsdp_text_encoder = False

# torch.compile gives a speedup at a fixed resolution for a little extra GPU
# memory. It cannot be combined with fsdp_dit or "sequential_cpu_offload".
compile_dit = False
# ---------------------------------------------------------------------------
# Model locations
# ---------------------------------------------------------------------------
# YAML with the extra transformer constructor kwargs.
config_path = "config/z_image/z_image_control.yaml"
# Base model directory (transformer / vae / tokenizer / text_encoder / scheduler subfolders).
model_name = "models/Diffusion_Transformer/Z-Image-Turbo/"

# Sampler: one of "Flow", "Flow_Unipc", "Flow_DPM++".
sampler_name = "Flow"

# Optional checkpoints overlaid on the base weights (None = use the base model only).
transformer_path = "models/Personalized_Model/Z-Image-Turbo-Fun-Controlnet-Union.safetensors"
vae_path = lora_path = None
# ---------------------------------------------------------------------------
# Generation parameters
# ---------------------------------------------------------------------------
sample_size = [1728, 992]  # [height, width]

# Some GPUs (e.g. V100, 2080 Ti) do not support bfloat16; switch to
# torch.float16 on those cards.
weight_dtype = torch.bfloat16

# Control input and how strongly it steers generation.
control_image = "asset/pose.jpg"
control_context_scale = 0.75

# Tip: a longer negative prompt such as "模糊, 突变, 变形, 失真, 画面暗, 文本字幕,
# 画面固定, 连环画, 漫画, 线稿, 没有主体." (blurry, mutation, deformation,
# distortion, dark image, text/subtitles, static frame, comic strip, comic,
# line art, no subject) can improve stability. Adding words such as
# "安静, 固定" (quiet, static) to the negative prompt can increase motion.
# NOTE(review): with guidance_scale = 0.0 the negative prompt may have no
# effect; confirm against the pipeline.
prompt = "一位年轻女子站在阳光明媚的海岸线上, 白裙在轻拂的海风中微微飘动.她拥有一头鲜艳的紫色长发, 在风中轻盈舞动, 发间系着一个精致的黑色蝴蝶结, 与身后柔和的蔚蓝天空形成鲜明对比.她面容清秀, 眉目精致, 透着一股甜美的青春气息;神情柔和, 略带羞涩, 目光静静地凝望着远方的地平线, 双手自然交叠于身前, 仿佛沉浸在思绪之中.在她身后, 是辽阔无垠、波光粼粼的大海, 阳光洒在海面上, 映出温暖的金色光晕."
# Alternative prompt:
# prompt = "一位身穿白色仙袍的仙人女子手持青色仙剑, 她抬头望着天空疾驰而来的雷霆, 面容严肃, 一袭紫色长发在风中飘扬, 仿佛与天地间的风雷共舞.她站立在一片古老的山巅, 背后是连绵起伏的群山和翻滚的乌云, 整个场景充满了神秘而壮丽的气息.天空中闪电划过, 照亮了她坚定的眼神和手中的仙剑, 彷佛预示着一场即将到来的大战.她的姿态优雅而坚定, 彷佛是天地间的守护者, 准备迎接任何挑战."
negative_prompt = " "

# Sampling.
guidance_scale = 0.0
seed = 43
num_inference_steps = 9
lora_weight = 0.55

# Output directory for generated images.
save_path = "samples/z-image-t2i-control"
device = set_multi_gpus_devices(ulysses_degree, ring_degree)
config = OmegaConf.load(config_path)

# Build the control transformer from the base model's "transformer" subfolder,
# passing the extra constructor kwargs from the YAML config.
transformer = ZImageControlTransformer2DModel.from_pretrained(
    model_name,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
    transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
).to(weight_dtype)

# Optionally overlay weights from a separate checkpoint. strict=False lets the
# checkpoint cover only a subset of parameters; the missing/unexpected key
# counts are printed as a sanity check.
if transformer_path is not None:
    print(f"From checkpoint: {transformer_path}")
    if transformer_path.endswith("safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(transformer_path)
    else:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only load trusted checkpoints.
        state_dict = torch.load(transformer_path, map_location="cpu")
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
    m, u = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
    # load_state_dict copied the tensors into the model; release the duplicate
    # copy instead of keeping it referenced for the rest of the script.
    del state_dict
# Optional: export the control transformer to ONNX and simplify it with
# onnxslim. Set EXPORT_TRANSFORMER_ONNX = True to run it; the script exits
# once the export finishes.
EXPORT_TRANSFORMER_ONNX = False
if EXPORT_TRANSFORMER_ONNX:
    class DummyControlTransformerWrapper(torch.nn.Module):
        """Give the transformer a positional signature suitable for export.

        control_context_scale is a Python float held on the wrapper rather than
        passed as an export input, so the traced graph uses this fixed value.
        """

        def __init__(self, model, control_context_scale=0.75):
            super().__init__()
            self.model = model
            self.control_context_scale = control_context_scale

        def forward(
            self,
            latent_model_input,
            timestep,
            prompt_embeds,
            control_context,
        ):
            model_out = self.model(
                latent_model_input,
                timestep,
                prompt_embeds,
                control_context=control_context,
                control_context_scale=self.control_context_scale,
            )
            return model_out

    # Example inputs. Shapes observed at runtime (pdb) with sample_size = [1728, 992]:
    #   latent_model_input[0]  torch.Size([16, 1, 216, 124])
    #   timestep               tensor([0.])
    #   prompt_embeds[0]       torch.Size([165, 2560])  (165 tokens for that prompt)
    #   control_context        torch.Size([1, 16, 1, 216, 124])
    #   control_context_scale  0.75
    #   model output[0]        torch.Size([16, 1, 216, 124])
    dummy_input = {
        "latent_model_input": [torch.randn(16, 1, sample_size[0] // 8, sample_size[1] // 8, device="cpu", dtype=torch.float32)],
        "timestep": torch.tensor([0.], device="cpu", dtype=torch.float32),
        # TODO: support the maximum prompt length; 512 tokens is hard-coded here.
        "prompt_embeds": [torch.randn(512, 2560, device="cpu", dtype=torch.float32)],
        "control_context": torch.randn(1, 16, 1, sample_size[0] // 8, sample_size[1] // 8, device="cpu", dtype=torch.float32),
    }

    # Use the configured control strength instead of a hard-coded copy.
    transformer_wrapper = DummyControlTransformerWrapper(transformer, control_context_scale)
    transformer_wrapper.eval()

    # A separate name, so the `transformer_path` checkpoint setting is not clobbered.
    transformer_onnx_dir = "onnx-models/trans/"
    transformer_onnx_path = os.path.join(transformer_onnx_dir, "z_image_control_transformer.onnx")
    os.makedirs(transformer_onnx_dir, exist_ok=True)

    # Export on CPU in float32. Module.to() converts the wrapped transformer in place.
    torch.onnx.export(
        transformer_wrapper.to(device="cpu", dtype=torch.float32),
        tuple(dummy_input.values()),
        transformer_onnx_path,
        opset_version=17,
        input_names=list(dummy_input.keys()),
        output_names=["model_out"],
        do_constant_folding=True,
        export_params=True,
        verbose=False,
    )

    # Re-save with every tensor in one external-data file. Give that file a
    # predictable name next to the model; without `location`, onnx picks a
    # random (uuid) file name.
    trans_onnx = onnx.load(transformer_onnx_path)
    simp_onnx_data = "onnx-models/z_image_control_transformer.onnx"
    onnx.save(
        trans_onnx,
        simp_onnx_data,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=os.path.basename(simp_onnx_data) + ".data",
    )
    logger.info("Transformer ONNX model exported, start to simplify.")
    # Equivalent of running `onnxslim <in>.onnx <in>_slim.onnx` in a terminal.
    success = run_onnxslim(simp_onnx_data, simp_onnx_data.replace(".onnx", "_slim.onnx"))
    if not success:
        sys.exit(1)
    logger.info("Transformer ONNX model exported successfully.")
    # Export-only run: stop before the rest of the pipeline is built.
    sys.exit()
# ---------------------------------------------------------------------------
# VAE
# ---------------------------------------------------------------------------
vae = AutoencoderKL.from_pretrained(
    model_name,
    subfolder="vae"
).to(weight_dtype)

# Optionally overlay VAE weights from a separate checkpoint. strict=False allows
# a partial checkpoint; the missing/unexpected key counts are printed.
if vae_path is not None:
    print(f"From checkpoint: {vae_path}")
    if vae_path.endswith("safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(vae_path)
    else:
        # NOTE(review): torch.load unpickles the file, which can execute
        # arbitrary code; only load trusted checkpoints.
        state_dict = torch.load(vae_path, map_location="cpu")
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
    m, u = vae.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
    # load_state_dict copied the tensors into the VAE; release the duplicate copy.
    del state_dict

# Optional: export the VAE encoder and decoder to ONNX (the script exits when
# done). This runs after the checkpoint overlay above, so the exported graphs
# carry the VAE weights actually used. Previously the export ran before
# vae_path was applied.
EXPORT_VAE_ONNX = False
if EXPORT_VAE_ONNX:
    class DummyVAEEncoderWrapper(torch.nn.Module):
        """Encoder for export: image -> ``.mode()`` of the first ``encode`` output."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, x):
            latent_dist = self.model.encode(x)[0].mode()
            return latent_dist

    class DummyVAEDecoderWrapper(torch.nn.Module):
        """Decoder for export: latents -> first element of ``decode(..., return_dict=False)``."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, latents):
            image = self.model.decode(latents, return_dict=False)[0]
            return image

    def export_vae_onnx(vae, sample_size):
        """Export the VAE encoder and decoder to ./onnx-models and slim both.

        Exits the process with status 1 if simplifying either model fails.
        Note: ``.to(torch.float32)`` converts ``vae`` itself to float32 in place.
        """
        vae.eval()
        os.makedirs("./onnx-models", exist_ok=True)

        # --- VAE encoder: dummy image of shape [1, 3, H, W] ---
        vae_encoder_onnx_path = "./onnx-models/vae_encoder.onnx"
        dummy_input = torch.randn(1, 3, sample_size[0], sample_size[1], device="cpu", dtype=torch.float32)
        vae_encode_wrapper = DummyVAEEncoderWrapper(vae)
        vae_encode_wrapper.eval()
        torch.onnx.export(vae_encode_wrapper.to(torch.float32), dummy_input, vae_encoder_onnx_path, opset_version=17)
        onnx.checker.check_model(vae_encoder_onnx_path)
        logger.info("VAE-Encoder ONNX model exported, start to simplify.")
        # Equivalent of running `onnxslim <in>.onnx <in>_slim.onnx` in a terminal.
        encoder_ok = run_onnxslim(vae_encoder_onnx_path, vae_encoder_onnx_path.replace(".onnx", "_slim.onnx"))

        # --- VAE decoder: dummy latents at 1/8 of the image resolution ---
        # (assumes an 8x spatial downsampling factor, matching the latent sizes used elsewhere in this script)
        vae_decoder_onnx_path = "./onnx-models/vae_decoder.onnx"
        dummy_latent = torch.randn(1, vae.config.latent_channels, sample_size[0] // 8, sample_size[1] // 8, device="cpu", dtype=torch.float32)
        vae_decode_wrapper = DummyVAEDecoderWrapper(vae)
        vae_decode_wrapper.eval()
        torch.onnx.export(vae_decode_wrapper.to(torch.float32), dummy_latent, input_names=["latent"], output_names=["image"], f=vae_decoder_onnx_path, opset_version=17)
        onnx.checker.check_model(vae_decoder_onnx_path)
        logger.info("VAE-Decoder ONNX model exported, start to simplify.")
        decoder_ok = run_onnxslim(vae_decoder_onnx_path, vae_decoder_onnx_path.replace(".onnx", "_slim.onnx"))

        # Both simplifications must succeed; previously an encoder failure was
        # overwritten by the decoder's result and went unnoticed.
        if encoder_ok and decoder_ok:
            logger.info("VAE ONNX model exported successfully.")
        else:
            sys.exit(1)

    export_vae_onnx(vae, sample_size)
    # Export-only run: stop before the rest of the pipeline is built.
    sys.exit()
# ---------------------------------------------------------------------------
# Tokenizer and text encoder (Qwen3)
# ---------------------------------------------------------------------------
tokenizer = AutoTokenizer.from_pretrained(
    model_name, subfolder="tokenizer"
)
text_encoder = Qwen3ForCausalLM.from_pretrained(
    model_name, subfolder="text_encoder", torch_dtype=weight_dtype,
    low_cpu_mem_usage=True,
)

# Optional ONNX export of the text encoder. Disabled by default: the text
# encoder (Qwen3 architecture) is currently compiled with llm_build instead.
# Unlike the transformer/VAE exports, this does not exit the script afterwards.
EXPORT_TEXT_ENCODER_ONNX = False
if EXPORT_TEXT_ENCODER_ONNX:
    import torch.nn as nn

    # Turn off the KV cache and extra outputs for export.
    text_encoder.eval()
    text_encoder.config.use_cache = False
    text_encoder.config.output_attentions = False
    text_encoder.config.output_hidden_states = False

    class Qwen3TextEncoderExporter(nn.Module):
        """Expose the text encoder's second-to-last hidden states as a single output."""

        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids, attention_mask=None):
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True
            )
            # hidden_states[-2]: the second-to-last layer, not the final one.
            return outputs.hidden_states[-2]

    wrapped_text_encoder = Qwen3TextEncoderExporter(text_encoder)
    wrapped_text_encoder.eval()

    max_sequence_length = 512
    text_encoder_onnx_path = "./onnx-models/text_encoder.onnx"
    # Create the output directory *before* exporting; previously it was created
    # only after torch.onnx.export had already tried to write into it.
    os.makedirs(os.path.dirname(text_encoder_onnx_path), exist_ok=True)

    # Input contract of the exported model: attention_mask has the same shape
    # as input_ids, marking the first N real tokens as valid and the trailing
    # padding as invalid. For input_ids of shape (1, 512) with 20 real tokens,
    # the mask is twenty valid entries followed by 492 padding entries. That
    # lets the model handle padding correctly. The dummy mask below is int64
    # (1 = valid, 0 = padding), so the exported graph takes an int64 mask.
    input_ids = torch.randint(0, tokenizer.vocab_size, (1, max_sequence_length), device="cpu", dtype=torch.long)
    attention_mask = torch.ones((1, max_sequence_length), device="cpu", dtype=torch.long)

    # Sanity-check a forward pass before exporting.
    with torch.no_grad():
        test_output = wrapped_text_encoder(input_ids, attention_mask)
    print(f"测试输出形状: {test_output.shape}")

    torch.onnx.export(
        wrapped_text_encoder,
        (input_ids, attention_mask),
        text_encoder_onnx_path,
        opset_version=17,
        input_names=["input_ids", "attention_mask"],
        # NOTE: despite the name, this output is hidden_states[-2] (see forward).
        output_names=["last_hidden_state"],
        do_constant_folding=True,
        export_params=True,
        verbose=False
    )
    onnx.checker.check_model(text_encoder_onnx_path)
    logger.info("Text Encoder ONNX model exported, start to simplify.")
    # Equivalent of running `onnxslim text_encoder.onnx text_encoder_slim.onnx` in a terminal.
    success = run_onnxslim(text_encoder_onnx_path, text_encoder_onnx_path.replace(".onnx", "_slim.onnx"))
    if success:
        logger.info("Text Encoder ONNX model exported successfully.")
    else:
        sys.exit(1)
# ---------------------------------------------------------------------------
# Scheduler and pipeline
# ---------------------------------------------------------------------------
# Map sampler_name to its scheduler class. Keep the dict and the chosen class
# as separate names; the previous chained assignment bound `scheduler_dict`
# to the chosen class rather than to the mapping.
scheduler_dict = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
if sampler_name not in scheduler_dict:
    raise ValueError(
        f"Unknown sampler_name {sampler_name!r}; expected one of {sorted(scheduler_dict)}"
    )
Chosen_Scheduler = scheduler_dict[sampler_name]
scheduler = Chosen_Scheduler.from_pretrained(
    model_name,
    subfolder="scheduler"
)

pipeline = ZImageControlPipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)
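# NOTE: the multi-GPU / FSDP / compile setup below is commented out, so in this
# script fsdp_dit, fsdp_text_encoder and compile_dit currently have no effect
# (ulysses_degree and ring_degree are only passed to set_multi_gpus_devices).
# The module names it references (single_transformer_blocks,
# text_encoder.language_model, MistralDecoderLayer, PixtralTransformer) may not
# match the Z-Image transformer / Qwen3 text encoder; verify before re-enabling.
# if ulysses_degree > 1 or ring_degree > 1: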
# from functools import partial
# transformer.enable_multi_gpus_inference()
# if fsdp_dit:
# shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, module_to_wrapper=list(transformer.transformer_blocks) + list(transformer.single_transformer_blocks))
# pipeline.transformer = shard_fn(pipeline.transformer)
# print("Add FSDP DIT")
# if fsdp_text_encoder:
# shard_fn = partial(shard_model, device_id=device, param_dtype=weight_dtype, module_to_wrapper=text_encoder.language_model.layers, ignored_modules=[text_encoder.language_model.embed_tokens], transformer_layer_cls_to_wrap=["MistralDecoderLayer", "PixtralTransformer"])
# text_encoder = shard_fn(text_encoder)
# print("Add FSDP TEXT ENCODER")
# if compile_dit:
# for i in range(len(pipeline.transformer.transformer_blocks)):
# pipeline.transformer.transformer_blocks[i] = torch.compile(pipeline.transformer.transformer_blocks[i])
# print("Add Compile")
# ---------------------------------------------------------------------------
# Device placement according to GPU_memory_mode
# ---------------------------------------------------------------------------
_GPU_MEMORY_MODES = (
    "model_full_load",
    "model_full_load_and_qfloat8",
    "model_cpu_offload",
    "model_cpu_offload_and_qfloat8",
    "sequential_cpu_offload",
)
# Fail fast on a misspelled mode instead of silently falling through to a
# full GPU load.
if GPU_memory_mode not in _GPU_MEMORY_MODES:
    raise ValueError(
        f"Unknown GPU_memory_mode {GPU_memory_mode!r}; expected one of {_GPU_MEMORY_MODES}"
    )

# The *_qfloat8 modes convert the transformer weights to float8 (skipping the
# modules named in exclude_module_name) and then apply
# convert_weight_dtype_wrapper with weight_dtype.
# NOTE(review): "img_in", "txt_in", "timestep" may not match module names in
# the Z-Image transformer; confirm.
if GPU_memory_mode == "sequential_cpu_offload":
    pipeline.enable_sequential_cpu_offload(device=device)
elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
    convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device)
    convert_weight_dtype_wrapper(transformer, weight_dtype)
    pipeline.enable_model_cpu_offload(device=device)
elif GPU_memory_mode == "model_cpu_offload":
    pipeline.enable_model_cpu_offload(device=device)
elif GPU_memory_mode == "model_full_load_and_qfloat8":
    convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device)
    convert_weight_dtype_wrapper(transformer, weight_dtype)
    pipeline.to(device=device)
else:  # "model_full_load"
    pipeline.to(device=device)

# Seeded generator on the target device, for reproducible sampling.
generator = torch.Generator(device=device).manual_seed(seed)
# LoRA merging is currently disabled, so lora_path / lora_weight are not applied.
# To enable it, uncomment:
#   if lora_path is not None:
#       pipeline = merge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
with torch.no_grad():
    # Replace the control-image path with the loaded tensor. Per the original
    # note its shape is [1, 3, sample_size[0], sample_size[1]] after taking
    # index 0 along dim 2 of get_image_latent's output.
    if control_image is not None:
        control_image = get_image_latent(control_image, sample_size=sample_size)[:, :, 0]

    sample = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        control_image=control_image,
        control_context_scale=control_context_scale,
        height=sample_size[0],
        width=sample_size[1],
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        generator=generator,
    ).images

# Matching LoRA unmerge (also disabled):
#   if lora_path is not None:
#       pipeline = unmerge_lora(pipeline, lora_path, lora_weight, device=device, dtype=weight_dtype)
def save_results(images=None, output_dir=None):
    """Save the first generated image as the next numbered PNG file.

    Files are named with a zero-padded 8-digit counter (``00000001.png``, ...).
    The counter starts at (number of entries in the directory + 1) and is
    advanced past any name that already exists, so an earlier result is never
    overwritten, e.g. after files were deleted from the directory.

    Args:
        images: Sequence whose first item has a ``save(path)`` method
            (presumably PIL images from the pipeline). Defaults to the
            module-level ``sample``.
        output_dir: Directory to write into; created if missing. Defaults to
            the module-level ``save_path``.

    Returns:
        The path of the written image.
    """
    images = sample if images is None else images
    output_dir = save_path if output_dir is None else output_dir
    os.makedirs(output_dir, exist_ok=True)

    index = len(os.listdir(output_dir)) + 1
    image_path = os.path.join(output_dir, f"{index:08d}.png")
    while os.path.exists(image_path):
        index += 1
        image_path = os.path.join(output_dir, f"{index:08d}.png")

    images[0].save(image_path)
    return image_path
# In a distributed (multi-process) run only rank 0 writes the result, so the
# ranks do not race on the same output file. When no process group is
# initialized (e.g. a single-GPU run) the result is always saved.
import torch.distributed as dist

if not (dist.is_available() and dist.is_initialized()) or dist.get_rank() == 0:
    save_results()