import os
import subprocess
import sys

import numpy as np
import onnx
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from loguru import logger
from omegaconf import OmegaConf
from PIL import Image

current_file_path = os.path.abspath(__file__)
project_roots = [
    os.path.dirname(current_file_path),
    os.path.dirname(os.path.dirname(current_file_path)),
    os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))),
]
for project_root in project_roots:
    if project_root not in sys.path:
        sys.path.insert(0, project_root)

from videox_fun.dist import set_multi_gpus_devices, shard_model
from videox_fun.models import (AutoencoderKL, AutoTokenizer,
                               Qwen3ForCausalLM, ZImageControlTransformer2DModel)
from videox_fun.models.cache_utils import get_teacache_coefficients
from videox_fun.pipeline import ZImageControlPipeline
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from videox_fun.utils.fp8_optimization import (convert_model_weight_to_float8,
                                               convert_weight_dtype_wrapper)
from videox_fun.utils.lora_utils import merge_lora, unmerge_lora
from videox_fun.utils.utils import (filter_kwargs, get_image_to_video_latent,
                                    get_image_latent, get_image,
                                    get_video_to_video_latent,
                                    save_videos_grid)


def run_onnxslim(input_file="vae.onnx", output_file="vae_slim.onnx"):
    """
    Run the `onnxslim` command-line tool to simplify an ONNX model.
    Returns True on success, False otherwise.
    """
    try:
        cmd = ["onnxslim", input_file, output_file]
        print(f"Running command: {' '.join(cmd)}")

        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            bufsize=1,
        )

        # Stream onnxslim's stdout while it is running.
        for line in process.stdout:
            print(line, end='')

        # Collect any remaining output and wait for the process to finish.
        stdout, stderr = process.communicate()

        if process.returncode != 0:
            print(f"Command failed, error output:\n{stderr}")
            return False
        else:
            print("ONNX model simplification finished!")
            return True

    except FileNotFoundError:
        print("Error: `onnxslim` command not found, please make sure onnxslim is installed")
        print("Install it with: pip install onnxslim")
        return False
    except Exception as e:
        print(f"Error while running command: {e}")
        return False
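

# Optional sanity check for an exported/slimmed model: a minimal sketch, assuming
# `onnxruntime` is installed (it is not required by the rest of this script).
# `onnx_path` and `feed` (a dict mapping input names to numpy arrays) are
# caller-provided and purely illustrative.
def check_onnx_outputs(onnx_path, feed):
    import onnxruntime as ort
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    # Passing None fetches every graph output.
    return session.run(None, feed)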


# GPU memory mode: "model_cpu_offload", "model_cpu_offload_and_qfloat8",
# "sequential_cpu_offload", "model_full_load_and_qfloat8", or anything else
# for a full GPU load (see the branches further below).
GPU_memory_mode = "model_cpu_offload"

# Multi-GPU parallel degrees; keep both at 1 for single-GPU runs.
ulysses_degree = 1
ring_degree = 1

# FSDP sharding switches for the transformer and the text encoder.
fsdp_dit = False
fsdp_text_encoder = False

# Whether to torch.compile the transformer.
compile_dit = False

config_path = "config/z_image/z_image_control.yaml"
model_name = "models/Diffusion_Transformer/Z-Image-Turbo/"

# Sampler choice: "Flow", "Flow_Unipc", or "Flow_DPM++".
sampler_name = "Flow"

# Additional checkpoints loaded on top of the base model (set to None to skip).
transformer_path = "models/Personalized_Model/Z-Image-Turbo-Fun-Controlnet-Union.safetensors"
vae_path = None
lora_path = None

# Output resolution as [height, width].
sample_size = [1728, 992]

weight_dtype = torch.bfloat16
control_image = "asset/pose.jpg"
control_context_scale = 0.75

prompt = "一位年轻女子站在阳光明媚的海岸线上, 白裙在轻拂的海风中微微飘动.她拥有一头鲜艳的紫色长发, 在风中轻盈舞动, 发间系着一个精致的黑色蝴蝶结, 与身后柔和的蔚蓝天空形成鲜明对比.她面容清秀, 眉目精致, 透着一股甜美的青春气息;神情柔和, 略带羞涩, 目光静静地凝望着远方的地平线, 双手自然交叠于身前, 仿佛沉浸在思绪之中.在她身后, 是辽阔无垠、波光粼粼的大海, 阳光洒在海面上, 映出温暖的金色光晕."
negative_prompt = " "
guidance_scale = 0.00
seed = 43
num_inference_steps = 9
lora_weight = 0.55
save_path = "samples/z-image-t2i-control"

device = set_multi_gpus_devices(ulysses_degree, ring_degree)
config = OmegaConf.load(config_path)

transformer = ZImageControlTransformer2DModel.from_pretrained(
    model_name,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
    transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
).to(weight_dtype)

if transformer_path is not None:
    print(f"From checkpoint: {transformer_path}")
    if transformer_path.endswith("safetensors"):
        from safetensors.torch import load_file, safe_open
        state_dict = load_file(transformer_path)
    else:
        state_dict = torch.load(transformer_path, map_location="cpu")
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

    m, u = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")


# Set to True to export the control transformer to ONNX and exit.
if False:
    class DummyControlTransformerWrapper(torch.nn.Module):
        """Thin wrapper that fixes control_context_scale so the traced graph
        only takes tensor inputs."""

        def __init__(self, model):
            super().__init__()
            self.model = model
            self.control_context_scale = 0.75

        def forward(
            self,
            latent_model_input,
            timestep,
            prompt_embeds,
            control_context,
        ):
            model_out = self.model(
                latent_model_input,
                timestep,
                prompt_embeds,
                control_context=control_context,
                control_context_scale=self.control_context_scale,
            )
            return model_out

    dummy_input = {
        "latent_model_input": [torch.randn(16, 1, sample_size[0] // 8, sample_size[1] // 8, device="cpu", dtype=torch.float32)],
        "timestep": torch.tensor([0.], device="cpu", dtype=torch.float32),
        "prompt_embeds": [torch.randn(512, 2560, device="cpu", dtype=torch.float32)],
        "control_context": torch.randn(1, 16, 1, sample_size[0] // 8, sample_size[1] // 8, device="cpu", dtype=torch.float32),
    }

    transformer_wrapper = DummyControlTransformerWrapper(transformer)
    transformer_wrapper.eval()

    transformer_onnx_dir = "onnx-models/trans/"
    transformer_onnx_path = os.path.join(transformer_onnx_dir, "z_image_control_transformer.onnx")
    if not os.path.exists(transformer_onnx_dir):
        os.makedirs(transformer_onnx_dir, exist_ok=True)

    torch.onnx.export(
        transformer_wrapper.to(device="cpu", dtype=torch.float32),
        tuple(dummy_input.values()),
        transformer_onnx_path,
        opset_version=17,
        input_names=list(dummy_input.keys()),
        output_names=["model_out"],
        do_constant_folding=True,
        export_params=True,
        verbose=False
    )
    trans_onnx = onnx.load(transformer_onnx_path)

    # Re-save with the weights as external data so the .onnx file stays under
    # the 2 GB protobuf limit.
    simp_onnx_data = "onnx-models/z_image_control_transformer.onnx"
    onnx.save(
        trans_onnx,
        simp_onnx_data,
        save_as_external_data=True,
        all_tensors_to_one_file=True
    )

    logger.info("Transformer ONNX model exported, start to simplify.")

    success = run_onnxslim(simp_onnx_data, simp_onnx_data.replace(".onnx", "_slim.onnx"))
    if success:
        logger.info("Transformer ONNX model exported successfully.")
    else:
        sys.exit(1)

    exit()

    # Reference shapes recorded in a pdb session inside the pipeline
    # (matching sample_size = [1728, 992]):
    """
    (Pdb) latent_model_input_list[0].shape
    torch.Size([16, 1, 216, 124])
    (Pdb) timestep_model_input
    tensor([0.], device='cuda:0')
    (Pdb) prompt_embeds_model_input[0].shape
    torch.Size([165, 2560])
    (Pdb) control_context.shape
    torch.Size([1, 16, 1, 216, 124])
    (Pdb) control_context_scale
    0.75
    (Pdb) model_out_list[0].shape
    torch.Size([16, 1, 216, 124])
    """


vae = AutoencoderKL.from_pretrained(
    model_name,
    subfolder="vae"
).to(weight_dtype)


# Set to True to export the VAE encoder/decoder to ONNX and exit.
if False:
    class DummyVAEEncoderWrapper(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, x):
            # Return the mode of the latent distribution instead of sampling,
            # so the exported graph is deterministic.
            latent_dist = self.model.encode(x)[0].mode()
            return latent_dist

    class DummyVAEDecoderWrapper(torch.nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, latents):
            image = self.model.decode(latents, return_dict=False)[0]
            return image

    def export_vae_onnx(vae, sample_size):
        vae.eval()
        if not os.path.exists("./onnx-models"):
            os.makedirs("./onnx-models", exist_ok=True)

        # Export the encoder: pixel image -> latent.
        vae_encoder_onnx_path = "./onnx-models/vae_encoder.onnx"
        dummy_input = torch.randn(1, 3, sample_size[0], sample_size[1], device="cpu", dtype=torch.float32)
        vae_encode_wrapper = DummyVAEEncoderWrapper(vae)
        vae_encode_wrapper.eval()
        torch.onnx.export(vae_encode_wrapper.to(torch.float32), dummy_input, vae_encoder_onnx_path, opset_version=17)

        onnx.checker.check_model(vae_encoder_onnx_path)
        logger.info("VAE-Encoder ONNX model exported, start to simplify.")
        encoder_success = run_onnxslim(vae_encoder_onnx_path, vae_encoder_onnx_path.replace(".onnx", "_slim.onnx"))

        # Export the decoder: latent -> pixel image.
        vae_decoder_onnx_path = "./onnx-models/vae_decoder.onnx"
        dummy_latent = torch.randn(1, vae.config.latent_channels, sample_size[0] // 8, sample_size[1] // 8, device="cpu", dtype=torch.float32)
        vae_decode_wrapper = DummyVAEDecoderWrapper(vae)
        vae_decode_wrapper.eval()
        torch.onnx.export(vae_decode_wrapper.to(torch.float32), dummy_latent, vae_decoder_onnx_path, input_names=["latent"], output_names=["image"], opset_version=17)
        onnx.checker.check_model(vae_decoder_onnx_path)
        logger.info("VAE-Decoder ONNX model exported, start to simplify.")
        decoder_success = run_onnxslim(vae_decoder_onnx_path, vae_decoder_onnx_path.replace(".onnx", "_slim.onnx"))

        if encoder_success and decoder_success:
            logger.info("VAE ONNX models exported successfully.")
        else:
            sys.exit(1)

    export_vae_onnx(vae, sample_size)
    exit()

if vae_path is not None:
    print(f"From checkpoint: {vae_path}")
    if vae_path.endswith("safetensors"):
        from safetensors.torch import load_file, safe_open
        state_dict = load_file(vae_path)
    else:
        state_dict = torch.load(vae_path, map_location="cpu")
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

    m, u = vae.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")


tokenizer = AutoTokenizer.from_pretrained(
    model_name, subfolder="tokenizer"
)
text_encoder = Qwen3ForCausalLM.from_pretrained(
    model_name, subfolder="text_encoder", torch_dtype=weight_dtype,
    low_cpu_mem_usage=True,
)


# Set to True to export the Qwen3 text encoder to ONNX and exit.
if False:
    text_encoder.eval()
    text_encoder.config.use_cache = False
    text_encoder.config.output_attentions = False
    text_encoder.config.output_hidden_states = False

    import torch.nn as nn

    class Qwen3TextEncoderExporter(nn.Module):
        def __init__(self, model):
            super().__init__()
            self.model = model

        def forward(self, input_ids, attention_mask=None):
            outputs = self.model(
                input_ids=input_ids,
                attention_mask=attention_mask,
                output_hidden_states=True,
                return_dict=True
            )
            # Return the penultimate hidden layer (hidden_states[-2]) as the text embedding.
            return outputs.hidden_states[-2]

    wrapped_text_encoder = Qwen3TextEncoderExporter(text_encoder)
    wrapped_text_encoder.eval()

    max_sequence_length = 512
    text_encoder_onnx_path = "./onnx-models/text_encoder.onnx"
    """
    NOTE: the attention_mask fed to the ONNX model must have the same size as
    input_ids, with the first N valid tokens set to True and the trailing
    padding set to False. For example, if input_ids has size (1, 512) and the
    actual prompt length is 20, then:
        attention_mask = [True, True, ..., True, False, False, ..., False]
        # 512 elements in total: the first 20 are True, the remaining 492 are False.
    This ensures the ONNX model handles the padded positions correctly at
    inference time and avoids wasted computation.
    """
    input_ids = torch.randint(0, tokenizer.vocab_size, (1, max_sequence_length), device="cpu", dtype=torch.long)
    attention_mask = torch.ones((1, max_sequence_length), device="cpu", dtype=torch.long)

    # Quick forward pass to confirm the wrapper's output shape before exporting.
    with torch.no_grad():
        test_output = wrapped_text_encoder(input_ids, attention_mask)
        print(f"Test output shape: {test_output.shape}")

    # Make sure the output directory exists before exporting.
    if not os.path.exists("./onnx-models"):
        os.makedirs("./onnx-models", exist_ok=True)

    torch.onnx.export(
        wrapped_text_encoder,
        (input_ids, attention_mask),
        text_encoder_onnx_path,
        opset_version=17,
        input_names=["input_ids", "attention_mask"],
        output_names=["last_hidden_state"],
        do_constant_folding=True,
        export_params=True,
        verbose=False
    )
    onnx.checker.check_model(text_encoder_onnx_path)
    logger.info("Text Encoder ONNX model exported, start to simplify.")

    success = run_onnxslim(text_encoder_onnx_path, text_encoder_onnx_path.replace(".onnx", "_slim.onnx"))
    if success:
        logger.info("Text Encoder ONNX model exported successfully.")
    else:
        sys.exit(1)


Chosen_Scheduler = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}[sampler_name]
scheduler = Chosen_Scheduler.from_pretrained(
    model_name,
    subfolder="scheduler"
)

pipeline = ZImageControlPipeline(
    vae=vae,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    transformer=transformer,
    scheduler=scheduler,
)

if GPU_memory_mode == "sequential_cpu_offload":
    pipeline.enable_sequential_cpu_offload(device=device)
elif GPU_memory_mode == "model_cpu_offload_and_qfloat8":
    convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device)
    convert_weight_dtype_wrapper(transformer, weight_dtype)
    pipeline.enable_model_cpu_offload(device=device)
elif GPU_memory_mode == "model_cpu_offload":
    pipeline.enable_model_cpu_offload(device=device)
elif GPU_memory_mode == "model_full_load_and_qfloat8":
    convert_model_weight_to_float8(transformer, exclude_module_name=["img_in", "txt_in", "timestep"], device=device)
    convert_weight_dtype_wrapper(transformer, weight_dtype)
    pipeline.to(device=device)
else:
    pipeline.to(device=device)

generator = torch.Generator(device=device).manual_seed(seed)

with torch.no_grad():
    if control_image is not None:
        control_image = get_image_latent(control_image, sample_size=sample_size)[:, :, 0]

    sample = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=sample_size[0],
        width=sample_size[1],
        generator=generator,
        guidance_scale=guidance_scale,
        control_image=control_image,
        num_inference_steps=num_inference_steps,
        control_context_scale=control_context_scale,
    ).images


def save_results():
    if not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)

    # Name the output after the number of files already in the folder.
    index = len(os.listdir(save_path)) + 1
    prefix = str(index).zfill(8)
    image_path = os.path.join(save_path, prefix + ".png")
    image = sample[0]
    image.save(image_path)


save_results()