| | |
| | """Export the Z-Image control transformer to ONNX for inference.""" |
| |
|
| | import argparse |
| | import logging |
| | import os |
| | import sys |
| | from collections import OrderedDict |
| | from typing import Any, Dict, List, Optional, OrderedDict as OrderedDictType, Tuple |
| |
|
| | import numpy as np |
| | import torch |
| | from omegaconf import OmegaConf |
| |
|
| | from loguru import logger |
| | import onnx |
| | from onnx import numpy_helper |
| | import subprocess |
| |
|
# Absolute path two directories above this file; it is prepended to sys.path
# (once) before the project-local ``videox_fun`` imports that follow.
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)
| |
|
| | from videox_fun.models import ZImageControlTransformer2DModel |
| | from videox_fun.models.z_image_transformer2d import pad_stack |
| |
|
# Standard-library logger used throughout this script; basicConfig sets the
# root level to INFO with a "[timestamp] LEVEL: message" format.
LOGGER = logging.getLogger("export_transformer_onnx")
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s")
| |
|
# Per-sample token counts are asserted to be multiples of this value in
# _prepare_transformer_state.
SEQ_MULTI_OF = 32
| |
|
| |
|
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        The parsed options as an ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Export the Z-Image control transformer to ONNX")
    add = parser.add_argument

    # Model sources.
    add("--config", default="config/z_image/z_image_control.yaml",
        help="Path to the YAML config used to build the transformer")
    add("--model-root", default="models/Diffusion_Transformer/Z-Image-Turbo/",
        help="Directory that stores the original diffusers weights")
    add("--checkpoint", default="models/Personalized_Model/Z-Image-Turbo-Fun-Controlnet-Union.safetensors",
        help="Optional fine-tuned checkpoint to load")

    # Output locations.
    add("--output", default="onnx-models/z_image_control_transformer.onnx",
        help="Target ONNX file path")
    add("--body-output", default="onnx-models/z_image_transformer_body.onnx",
        help="Path for the body-only ONNX when --split-control is enabled")
    add("--control-output", default="onnx-models/z_image_controlnet.onnx",
        help="Path for the control-only ONNX when --split-control is enabled")

    # Shapes of the sample inputs used for tracing.
    add("--height", type=int, default=864,
        help="Target image height used to derive latent resolution")
    add("--width", type=int, default=496,
        help="Target image width used to derive latent resolution")
    add("--batch-size", type=int, default=1,
        help="Batch size for the exported graph")
    add("--sequence-length", type=int, default=512,
        help="Prompt embedding sequence length (must be a multiple of 32)")
    add("--frames", type=int, default=1,
        help="Number of frames in the latent tensor")
    add("--latent-downsample-factor", type=int, default=8,
        help="Downsampling ratio between spatial image size and latent size")
    add("--latent-height", type=int, default=None,
        help="Override latent height (after downsampling)")
    add("--latent-width", type=int, default=None,
        help="Override latent width (after downsampling)")

    # Numerics and model hyper-parameters.
    add("--dtype", choices=["fp16", "fp32"], default="fp16",
        help="Export precision")
    add("--control-scale", type=float, default=0.75,
        help="Default control context scale input")
    add("--patch-size", type=int, default=2,
        help="Spatial patch size used by the transformer")
    add("--f-patch-size", type=int, default=1,
        help="Frame patch size used by the transformer")
    add("--opset", type=int, default=17,
        help="ONNX opset version")

    # Export / validation switches.
    add("--no-external-data", action="store_true",
        help="Disable external data format even if the model is larger than 2GB")
    add("--skip-ort-check", action="store_true",
        help="Skip running an ONNX Runtime correctness check")
    add("--ort-provider", default="CPUExecutionProvider",
        help="ONNX Runtime provider used during validation")
    add("--split-control", action="store_true",
        help="Export transformer body and ControlNet separately instead of a fused model")
    add("--save-calib-inputs", action="store_true",
        help="Dump ONNX input dictionaries as .npy for calibration")
    add("--calib-dir", default="onnx-calibration",
        help="Directory for storing calibration npy files")
    add("--dynamic-axes", action="store_true",
        help="Export ONNX with dynamic batch/seq/latent dims; default is static shape")
    add("--skip-slim", action="store_true",
        help="Skip onnxslim simplification for faster debug export")

    return parser.parse_args()
| |
|
| |
|
def run_onnxslim(input_file="vae.onnx", output_file="vae_slim.onnx"):
    """Run the ``onnxslim`` command-line tool to slim an ONNX model.

    The child's output is streamed to this process's stdout line by line.
    stderr is merged into stdout so that a single pipe is drained; reading one
    pipe to EOF while the other is left unread can deadlock once the child
    fills the unread pipe's buffer.

    Args:
        input_file: Path of the ONNX model to slim.
        output_file: Path where the slimmed model is written.

    Returns:
        True when the command exits with status 0, False when the command is
        missing, exits non-zero, or launching/reading it raises.
    """
    cmd = ["onnxslim", input_file, output_file]
    print(f"执行命令: {' '.join(cmd)}")

    captured = []
    try:
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # single pipe -> no stdout/stderr deadlock
            text=True,
            bufsize=1,
        ) as process:
            for line in process.stdout:
                print(line, end='')
                captured.append(line)
        # Leaving the ``with`` block waits for the child, so returncode is set.
        returncode = process.returncode
    except FileNotFoundError:
        print("错误: 未找到 onnxslim 命令, 请确保已安装 onnxslim")
        # The executable is shipped by the ``onnxslim`` package.
        print("安装方法: pip install onnxslim")
        return False
    except Exception as e:
        # Best-effort helper: report and signal failure instead of raising.
        print(f"执行命令时发生错误: {e}")
        return False

    if returncode != 0:
        # Only the tail of the (already streamed) output is repeated here.
        stderr = "".join(captured[-20:])
        print(f"命令执行失败, 错误信息:\n{stderr}")
        return False

    print("ONNX模型压缩完成!")
    return True
| |
|
| |
|
| | def _resolve_path(path: str) -> str: |
| | return os.path.abspath(os.path.join(REPO_ROOT, path)) if not os.path.isabs(path) else path |
| |
|
| |
|
def load_transformer(args: argparse.Namespace, torch_dtype: torch.dtype, device: torch.device) -> ZImageControlTransformer2DModel:
    """Instantiate the control transformer and optionally apply a checkpoint.

    Args:
        args: Parsed CLI options; reads ``config``, ``model_root`` and
            ``checkpoint`` (relative paths are resolved via ``_resolve_path``).
        torch_dtype: Precision the model is loaded in and cast to.
        device: Device the model is moved to.

    Returns:
        The model in eval mode on ``device`` with ``torch_dtype``.

    Raises:
        FileNotFoundError: If the config file, the model root directory, or a
            requested checkpoint does not exist.
    """
    config_path = _resolve_path(args.config)
    model_root = _resolve_path(args.model_root)
    checkpoint_path = _resolve_path(args.checkpoint) if args.checkpoint else None

    # Validate every input path before the (expensive) model load so a typo in
    # --checkpoint fails immediately rather than after loading the base weights.
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config not found: {config_path}")
    if not os.path.isdir(model_root):
        raise FileNotFoundError(f"Model root not found: {model_root}")
    if checkpoint_path and not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    config = OmegaConf.load(config_path)
    transformer_kwargs = OmegaConf.to_container(config.get("transformer_additional_kwargs", {}), resolve=True)

    LOGGER.info("Loading transformer from %s", model_root)
    transformer = ZImageControlTransformer2DModel.from_pretrained(
        model_root,
        subfolder="transformer",
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
        transformer_additional_kwargs=transformer_kwargs,
    )
    transformer.eval()
    transformer.to(device=device, dtype=torch_dtype)

    if checkpoint_path:
        LOGGER.info("Loading checkpoint %s", checkpoint_path)
        if checkpoint_path.endswith(".safetensors"):
            from safetensors.torch import load_file

            state_dict = load_file(checkpoint_path)
        else:
            # NOTE(review): torch.load unpickles the file; only use it on
            # checkpoints from a trusted source.
            state_dict = torch.load(checkpoint_path, map_location="cpu")
        # Some checkpoints nest the weights under a "state_dict" key.
        state_dict = state_dict.get("state_dict", state_dict)
        # strict=False: mismatching keys are tolerated and only counted below.
        missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
        LOGGER.info("Checkpoint loaded (missing=%d, unexpected=%d)", len(missing), len(unexpected))

    return transformer
| |
|
| |
|
| | def _tensor_batch_to_list(batch_tensor: torch.Tensor) -> List[torch.Tensor]: |
| | return [batch_tensor[i] for i in range(batch_tensor.shape[0])] |
| |
|
| |
|
def _prepare_transformer_state(
    model: ZImageControlTransformer2DModel,
    latent_list: List[torch.Tensor],
    prompt_list: List[torch.Tensor],
    timestep: torch.Tensor,
    patch_size: int,
    f_patch_size: int,
) -> Dict[str, Any]:
    """Run the stages that precede the main transformer layers.

    Steps performed, in order:
      1. Scale the timestep by ``model.t_scale`` and embed it with ``model.t_embedder``.
      2. Call ``model.patchify_and_embed`` on the latent and prompt lists.
      3. Embed the image tokens, substitute ``model.x_pad_token`` where the pad
         mask is set, and run ``model.noise_refiner``.
      4. Embed the caption tokens, substitute ``model.cap_pad_token`` where the
         pad mask is set, and run ``model.context_refiner``.
      5. Concatenate, per sample, image tokens followed by caption tokens into a
         padded "unified" sequence with matching rotary frequencies and mask.

    Args:
        model: Transformer providing the embedders, refiners and pad tokens.
        latent_list: One latent tensor per sample.
        prompt_list: One prompt-embedding tensor per sample.
        timestep: Timestep tensor; cast to float32 on the latents' device.
        patch_size: Spatial patch size; also selects the entry of ``all_x_embedder``.
        f_patch_size: Frame patch size; also selects the entry of ``all_x_embedder``.

    Returns:
        Dict with keys ``x``, ``cap_feats``, ``unified``, ``kwargs``,
        ``adaln_input``, ``time_embed``, ``x_item_seqlens``,
        ``cap_item_seqlens``, ``x_size``, ``unified_attn_mask``,
        ``unified_freqs_cis`` and ``bsz``.

    Raises:
        AssertionError: If any per-sample sequence length is not a multiple of
            ``SEQ_MULTI_OF``.
    """
    bsz = len(latent_list)
    device = latent_list[0].device
    timestep = timestep.to(device=device, dtype=torch.float32)
    t = timestep * model.t_scale
    t = model.t_embedder(t)

    (
        x,
        cap_feats,
        x_size,
        x_pos_ids,
        cap_pos_ids,
        x_inner_pad_mask,
        cap_inner_pad_mask,
    ) = model.patchify_and_embed(latent_list, prompt_list, patch_size, f_patch_size)

    # ---- Image tokens: embed, substitute pad token, refine ----
    x_item_seqlens = [len(_) for _ in x]
    assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
    x_max_item_seqlen = max(x_item_seqlens)

    # All samples are concatenated along dim 0 so the embedder runs once.
    x = torch.cat(x, dim=0)
    x = model.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
    adaln_input = t.type_as(x)

    mask = torch.cat(x_inner_pad_mask)
    if torch.onnx.is_in_onnx_export():
        # During ONNX export the boolean-mask assignment of the eager branch is
        # expressed as torch.where over an expanded pad token instead.
        if model.x_pad_token.dim() == 1:
            x_pad_token_2d = model.x_pad_token.unsqueeze(0)
        else:
            x_pad_token_2d = model.x_pad_token
        mask_2d = mask.unsqueeze(1)
        x_pad_expanded = x_pad_token_2d.expand_as(x)
        x = torch.where(mask_2d, x_pad_expanded, x)
    else:
        x[mask] = model.x_pad_token

    # Back to one tensor per sample, then pad/stack to the longest sample.
    x = list(x.split(x_item_seqlens, dim=0))
    x_freqs_cis = list(model.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))

    x = pad_stack(x, x_max_item_seqlen, pad_value=0.0)
    x_freqs_cis = pad_stack(x_freqs_cis, x_max_item_seqlen, pad_value=0.0)
    # True for the first seq_len positions of each sample, False for padding.
    x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(x_item_seqlens):
        x_attn_mask[i, :seq_len] = 1

    for layer in model.noise_refiner:
        x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)

    # ---- Caption tokens: same procedure with the caption embedder/refiner ----
    cap_item_seqlens = [len(_) for _ in cap_feats]
    assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
    cap_max_item_seqlen = max(cap_item_seqlens)

    cap_feats = torch.cat(cap_feats, dim=0)
    cap_feats = model.cap_embedder(cap_feats)

    cap_mask = torch.cat(cap_inner_pad_mask)
    if torch.onnx.is_in_onnx_export():
        # Export-time torch.where form of the masked assignment (see image branch).
        if model.cap_pad_token.dim() == 1:
            cap_pad_token_2d = model.cap_pad_token.unsqueeze(0)
        else:
            cap_pad_token_2d = model.cap_pad_token
        mask_2d = cap_mask.unsqueeze(1)
        cap_pad_expanded = cap_pad_token_2d.expand_as(cap_feats)
        cap_feats = torch.where(mask_2d, cap_pad_expanded, cap_feats)
    else:
        cap_feats[cap_mask] = model.cap_pad_token

    cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
    cap_freqs_cis = list(model.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))

    cap_feats = pad_stack(cap_feats, cap_max_item_seqlen, pad_value=0.0)
    cap_freqs_cis = pad_stack(cap_freqs_cis, cap_max_item_seqlen, pad_value=0.0)
    cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(cap_item_seqlens):
        cap_attn_mask[i, :seq_len] = 1

    # Unlike noise_refiner, the context refiner layers receive no adaln_input.
    for layer in model.context_refiner:
        cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)

    # ---- Sequence-parallel split: keep only this rank's chunk along dim 1 ----
    if model.sp_world_size > 1:
        x = torch.chunk(x, model.sp_world_size, dim=1)[model.sp_world_rank]
        # x is a stacked tensor here, so len() of each row is the chunked length.
        x_item_seqlens = [len(_) for _ in x]
        x_max_item_seqlen = max(x_item_seqlens)
        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(x_item_seqlens):
            x_attn_mask[i, :seq_len] = 1

        if x_freqs_cis is not None:
            x_freqs_cis = torch.chunk(x_freqs_cis, model.sp_world_size, dim=1)[model.sp_world_rank]

    # ---- Unified sequence: image tokens followed by caption tokens ----
    unified = []
    unified_freqs_cis = []
    for i in range(bsz):
        x_len = x_item_seqlens[i]
        cap_len = cap_item_seqlens[i]
        unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
        unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
    unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
    unified_max_item_seqlen = max(unified_item_seqlens)

    unified = pad_stack(unified, unified_max_item_seqlen, pad_value=0.0)
    unified_freqs_cis = pad_stack(unified_freqs_cis, unified_max_item_seqlen, pad_value=0.0)
    unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(unified_item_seqlens):
        unified_attn_mask[i, :seq_len] = 1

    # Keyword arguments bundled for the caller (ControlNetWrapper passes them
    # to model.forward_control).
    kwargs = dict(attn_mask=unified_attn_mask, freqs_cis=unified_freqs_cis, adaln_input=adaln_input)

    return dict(
        x=x,
        cap_feats=cap_feats,
        unified=unified,
        kwargs=kwargs,
        adaln_input=adaln_input,
        time_embed=t,
        x_item_seqlens=x_item_seqlens,
        cap_item_seqlens=cap_item_seqlens,
        x_size=x_size,
        unified_attn_mask=unified_attn_mask,
        unified_freqs_cis=unified_freqs_cis,
        bsz=bsz,
    )
| |
|
| |
|
class TransformerOnnxWrapper(torch.nn.Module):
    """Adapter giving the fused model a tensor-only signature for ONNX export.

    Batched tensors are unbound along dim 0 into per-sample lists before the
    wrapped model is called; only the first of its two return values is
    returned.
    """

    def __init__(self, model: ZImageControlTransformer2DModel, patch_size: int, f_patch_size: int):
        super().__init__()
        self.model = model
        self.patch_size = patch_size
        self.f_patch_size = f_patch_size

    def forward(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,
        prompt_embeds: torch.Tensor,
        control_context: torch.Tensor,
        control_context_scale: torch.Tensor,
    ) -> torch.Tensor:
        """Call the wrapped model and return its first output."""
        per_sample_latents = list(latent_model_input.unbind(dim=0))
        per_sample_prompts = list(prompt_embeds.unbind(dim=0))
        # The scale is aligned to the latents' device and dtype.
        scale = control_context_scale.to(
            device=latent_model_input.device,
            dtype=latent_model_input.dtype,
        )
        model_kwargs = dict(
            patch_size=self.patch_size,
            f_patch_size=self.f_patch_size,
            control_context=control_context,
            control_context_scale=scale,
        )
        sample, _ = self.model(per_sample_latents, timestep, per_sample_prompts, **model_kwargs)
        return sample
| |
|
| |
|
class ControlNetWrapper(torch.nn.Module):
    """Tensor-only adapter that runs just the control branch of the model.

    The forward pass reuses ``_prepare_transformer_state`` and then calls
    ``model.forward_control``; the resulting hints are stacked along a new
    leading dimension.
    """

    def __init__(self, model: ZImageControlTransformer2DModel, patch_size: int, f_patch_size: int):
        super().__init__()
        self.model = model
        self.patch_size = patch_size
        self.f_patch_size = f_patch_size

    def forward(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,
        prompt_embeds: torch.Tensor,
        control_context: torch.Tensor,
    ) -> torch.Tensor:
        """Return the control hints stacked into a single tensor."""
        latent_items = _tensor_batch_to_list(latent_model_input)
        prompt_items = _tensor_batch_to_list(prompt_embeds)
        control_items = _tensor_batch_to_list(control_context)

        prepared = _prepare_transformer_state(
            self.model, latent_items, prompt_items, timestep, self.patch_size, self.f_patch_size
        )

        control_hints = self.model.forward_control(
            prepared["unified"],
            prepared["cap_feats"],
            control_items,
            prepared["kwargs"],
            t=prepared["time_embed"],
            patch_size=self.patch_size,
            f_patch_size=self.f_patch_size,
        )
        return torch.stack(control_hints, dim=0)
| |
|
| |
|
class TransformerBodyWrapper(torch.nn.Module):
    """Tensor-only adapter for the transformer body fed with precomputed hints.

    The forward pass prepares the unified sequence, runs ``model.layers`` with
    the supplied hints and scale, applies the final layer matching the
    configured patch sizes, and unpatchifies the result.
    """

    def __init__(self, model: ZImageControlTransformer2DModel, patch_size: int, f_patch_size: int):
        super().__init__()
        self.model = model
        self.patch_size = patch_size
        self.f_patch_size = f_patch_size

    def forward(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,
        prompt_embeds: torch.Tensor,
        control_hints: torch.Tensor,
        control_context_scale: torch.Tensor,
    ) -> torch.Tensor:
        """Return the stacked output of ``model.unpatchify``."""
        latent_items = _tensor_batch_to_list(latent_model_input)
        prompt_items = _tensor_batch_to_list(prompt_embeds)
        hint_items = list(torch.unbind(control_hints, dim=0))
        scale = control_context_scale.to(
            device=latent_model_input.device,
            dtype=latent_model_input.dtype,
        )

        prepared = _prepare_transformer_state(
            self.model, latent_items, prompt_items, timestep, self.patch_size, self.f_patch_size
        )

        # Every layer receives the same keyword arguments, so build them once.
        shared_kwargs = dict(
            attn_mask=prepared["unified_attn_mask"],
            freqs_cis=prepared["unified_freqs_cis"],
            adaln_input=prepared["adaln_input"],
            hints=hint_items,
            context_scale=scale,
        )
        hidden = prepared["unified"]
        for block in self.model.layers:
            hidden = block(hidden, **shared_kwargs)

        if self.model.sp_world_size > 1:
            # Keep each sample's first x_len tokens, then gather across ranks.
            image_lens = prepared["x_item_seqlens"]
            hidden = torch.stack([hidden[idx, : image_lens[idx]] for idx in range(prepared["bsz"])])
            hidden = self.model.all_gather(hidden, dim=1)

        head = self.model.all_final_layer[f"{self.patch_size}-{self.f_patch_size}"]
        hidden = head(hidden, prepared["adaln_input"])
        per_sample = list(hidden.unbind(dim=0))
        images = self.model.unpatchify(per_sample, prepared["x_size"], self.patch_size, self.f_patch_size)
        return torch.stack(images)
| |
|
| |
|
| | def _validate_sequence_length(seq_len: int) -> None: |
| | if seq_len % 32 != 0: |
| | raise ValueError("sequence_length must be a multiple of 32 to satisfy transformer padding rules") |
| |
|
| |
|
| | def _compute_latent_dims(args: argparse.Namespace) -> Dict[str, int]: |
| | if args.latent_height is not None and args.latent_width is not None: |
| | latent_h = args.latent_height |
| | latent_w = args.latent_width |
| | else: |
| | if args.height % args.latent_downsample_factor != 0 or args.width % args.latent_downsample_factor != 0: |
| | raise ValueError("height and width must be divisible by latent_downsample_factor") |
| | latent_h = args.height // args.latent_downsample_factor |
| | latent_w = args.width // args.latent_downsample_factor |
| | if latent_h % args.patch_size != 0 or latent_w % args.patch_size != 0: |
| | raise ValueError("latent dimensions must be divisible by patch_size") |
| | if args.frames % args.f_patch_size != 0: |
| | raise ValueError("frames must be divisible by f_patch_size") |
| | return {"latent_h": latent_h, "latent_w": latent_w} |
| |
|
| |
|
def build_dummy_inputs(
    args: argparse.Namespace,
    model: ZImageControlTransformer2DModel,
    torch_dtype: torch.dtype,
    device: torch.device,
) -> OrderedDictType[str, torch.Tensor]:
    """Create random sample tensors matching the export configuration.

    Args:
        args: Parsed CLI options (batch size, frames, sequence length,
            control scale and the fields read by ``_compute_latent_dims``).
        model: Supplies ``config.in_channels`` and ``config.cap_feat_dim``.
        torch_dtype: dtype of the latent, prompt and control tensors.
        device: Device on which all tensors are created.

    Returns:
        Ordered mapping with keys ``latent_model_input``, ``timestep``,
        ``prompt_embeds``, ``control_context`` and ``control_context_scale``;
        timestep and scale are float32 regardless of ``torch_dtype``.

    Raises:
        ValueError: Propagated from the sequence-length and latent-dimension
            validators.
    """
    _validate_sequence_length(args.sequence_length)
    latent_dims = _compute_latent_dims(args)
    n = args.batch_size

    latent_shape = (
        n,
        model.config.in_channels,
        args.frames,
        latent_dims["latent_h"],
        latent_dims["latent_w"],
    )
    prompt_shape = (n, args.sequence_length, model.config.cap_feat_dim)

    # The creation order below is kept fixed so the random draws are unchanged.
    latent = torch.randn(*latent_shape, dtype=torch_dtype, device=device)
    timestep = torch.linspace(0.0, 1.0, steps=n, dtype=torch.float32, device=device)
    prompts = torch.randn(*prompt_shape, dtype=torch_dtype, device=device)
    control = torch.randn_like(latent)
    control_scale = torch.full((1,), args.control_scale, dtype=torch.float32, device=device)

    return OrderedDict(
        [
            ("latent_model_input", latent),
            ("timestep", timestep),
            ("prompt_embeds", prompts),
            ("control_context", control),
            ("control_context_scale", control_scale),
        ]
    )
| |
|
| |
|
def maybe_save_calibration_inputs(tag: str, inputs: OrderedDictType[str, torch.Tensor], args: argparse.Namespace) -> Optional[str]:
    """Dump the given inputs to ``<calib_dir>/<tag>_inputs.npy`` when enabled.

    Args:
        tag: Prefix used in the output file name.
        inputs: Named tensors; each is detached, moved to CPU and converted to
            a NumPy array.
        args: Parsed CLI options; reads ``save_calib_inputs`` (treated as
            False when absent) and ``calib_dir``.

    Returns:
        The path of the written file, or None when saving is disabled.
    """
    enabled = getattr(args, "save_calib_inputs", False)
    if not enabled:
        return None

    target_dir = _resolve_path(args.calib_dir)
    os.makedirs(target_dir, exist_ok=True)

    arrays = {}
    for key, value in inputs.items():
        arrays[key] = value.detach().cpu().numpy()

    # The whole dict is stored as a single pickled object inside the .npy file.
    destination = os.path.join(target_dir, f"{tag}_inputs.npy")
    np.save(destination, arrays, allow_pickle=True)
    LOGGER.info("Saved calibration inputs (%s) to %s", tag, destination)
    return destination
| |
|
| |
|
def dump_initializer_parameters(model_path: str) -> str:
    """Write every graph initializer of an ONNX model to ``<model_path>.params.npz``.

    Args:
        model_path: Path of the ONNX model; external data is loaded as well.

    Returns:
        Path of the written ``.npz`` archive, keyed by initializer name.
    """
    proto = onnx.load(model_path, load_external_data=True)
    arrays = {
        tensor.name: numpy_helper.to_array(tensor)
        for tensor in proto.graph.initializer
    }
    archive_path = f"{model_path}.params.npz"
    np.savez(archive_path, **arrays)
    LOGGER.info("Saved %d parameters to %s", len(arrays), archive_path)
    return archive_path
| |
|
| |
|
def export_onnx(
    wrapper: torch.nn.Module,
    sample_inputs: OrderedDictType[str, torch.Tensor],
    output_path: str,
    output_names: List[str],
    args: argparse.Namespace,
) -> Tuple[str, str]:
    """Export ``wrapper`` to ONNX, re-save it, and optionally slim it.

    Files produced for an ``output_path`` of ``<stem>.onnx``:
      * ``<stem>.onnx``            - raw ``torch.onnx.export`` result,
      * ``<stem>_simp.onnx``       - re-saved copy (weights in
        ``<stem>_simp.onnx.data`` unless ``--no-external-data`` is given),
      * ``<stem>_simp_slim.onnx``  - onnxslim output (unless ``--skip-slim``).

    Args:
        wrapper: Module to trace; put into eval mode here.
        sample_inputs: Ordered name -> tensor mapping; the order defines the
            positional arguments and the ONNX input names.
        output_path: Destination of the raw export (resolved via ``_resolve_path``).
        output_names: ONNX output names.
        args: Parsed CLI options; reads ``no_external_data``, ``dynamic_axes``,
            ``opset`` and ``skip_slim``.

    Returns:
        ``(final_onnx_path, param_path)``. ``param_path`` is always the empty
        string because initializers are not dumped by this function.

    Raises:
        RuntimeError: If onnxslim is requested and fails.
    """
    export_path = _resolve_path(output_path)
    export_dir = os.path.dirname(export_path)
    if export_dir:
        os.makedirs(export_dir, exist_ok=True)
    input_names = list(sample_inputs.keys())
    use_external = not args.no_external_data
    wrapper.eval()

    dynamic_axes = None
    if args.dynamic_axes:
        # NOTE(review): the axis labels for "hints"/"control_hints" reuse the
        # latent layout; confirm they match the actual hint tensor rank.
        candidate_axes = {
            "latent_model_input": {0: "batch", 2: "frames", 3: "latent_h", 4: "latent_w"},
            "prompt_embeds": {0: "batch", 1: "seq_len"},
            "timestep": {0: "batch"},
            "control_context": {0: "batch", 2: "frames", 3: "latent_h", 4: "latent_w"},
            "control_hints": {0: "batch", 2: "frames", 3: "latent_h", 4: "latent_w"},
            "control_context_scale": {0: "scale_batch"},
            "sample": {0: "batch", 2: "frames", 3: "latent_h", 4: "latent_w"},
            "hints": {0: "batch", 2: "frames", 3: "latent_h", 4: "latent_w"},
        }
        # Only names that are real inputs/outputs of this export are kept.
        io_names = set(input_names) | set(output_names)
        dynamic_axes = {name: axes for name, axes in candidate_axes.items() if name in io_names}

    LOGGER.info("Exporting ONNX to %s", export_path)
    with torch.inference_mode():
        torch.onnx.export(
            wrapper,
            args=tuple(sample_inputs[name] for name in input_names),
            f=export_path,
            input_names=input_names,
            output_names=output_names,
            opset_version=args.opset,
            do_constant_folding=True,
            export_params=True,
            dynamic_axes=dynamic_axes,
        )

    LOGGER.info("Raw ONNX export finished")

    trans_onnx = onnx.load(export_path)
    simp_onnx_data = os.path.splitext(export_path)[0] + "_simp.onnx"
    if use_external:
        external_weight_file = simp_onnx_data + ".data"
        # onnx appends to an existing external-data file, so a stale file from
        # a previous run is removed before saving.
        if os.path.exists(external_weight_file):
            os.remove(external_weight_file)
        onnx.save(
            trans_onnx,
            simp_onnx_data,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
            # Without an explicit location onnx picks its own file name, which
            # would not match the path logged below.
            location=os.path.basename(external_weight_file),
        )
        LOGGER.info("Saved external-data ONNX to %s (weights -> %s)", simp_onnx_data, external_weight_file)
    else:
        # --no-external-data: keep the weights inside the .onnx file. This can
        # fail for models above the 2GB protobuf limit.
        onnx.save(trans_onnx, simp_onnx_data)
        LOGGER.info("Saved single-file ONNX to %s", simp_onnx_data)

    if args.skip_slim:
        LOGGER.info("Skip onnxslim as requested, using simplified external-data ONNX: %s", simp_onnx_data)
        final_onnx = simp_onnx_data
    else:
        slim_onnx_path = os.path.splitext(simp_onnx_data)[0] + "_slim.onnx"
        LOGGER.info("Transformer ONNX model exported, start to simplify via onnxslim")
        success = run_onnxslim(simp_onnx_data, slim_onnx_path)
        if not success:
            raise RuntimeError("onnxslim simplification failed, please check logs")
        final_onnx = slim_onnx_path
    LOGGER.info("Transformer ONNX model exported successfully: %s", final_onnx)

    # Initializer dumping is not performed here; the second tuple element is
    # kept only for interface compatibility.
    param_path = ""

    return final_onnx, param_path
| |
|
| |
|
def run_ort_validation(
    wrapper: torch.nn.Module,
    sample_inputs: OrderedDictType[str, torch.Tensor],
    onnx_path: str,
    provider: str,
) -> None:
    """Compare the PyTorch wrapper with the exported model in ONNX Runtime.

    The maximum absolute difference and a relative difference (normalised by
    ``max(1, max|torch_output|)``) are logged; no tolerance is enforced.

    Args:
        wrapper: Module that was exported; run in eval/inference mode.
        sample_inputs: Ordered name -> tensor mapping used for both runs.
        onnx_path: Path of the ONNX model to load.
        provider: ONNX Runtime execution provider name.

    Raises:
        ValueError: If the two outputs have different shapes.
    """
    try:
        import onnxruntime as ort
    except ImportError:
        LOGGER.warning("onnxruntime not installed, skip validation")
        return

    wrapper.eval()
    with torch.inference_mode():
        torch_output = wrapper(*sample_inputs.values()).detach().cpu().numpy()

    sess_options = ort.SessionOptions()
    session = ort.InferenceSession(onnx_path, sess_options=sess_options, providers=[provider])

    # The exported graph may not declare every sample input (e.g. after
    # pruning of unused inputs); feeding unknown names makes session.run fail.
    graph_input_names = {graph_input.name for graph_input in session.get_inputs()}
    ort_inputs = {
        name: tensor.detach().cpu().numpy()
        for name, tensor in sample_inputs.items()
        if name in graph_input_names
    }
    skipped = [name for name in sample_inputs if name not in graph_input_names]
    if skipped:
        LOGGER.info("Inputs not present in the ONNX graph were not fed: %s", ", ".join(skipped))
    ort_output = session.run(None, ort_inputs)[0]

    # A shape mismatch would otherwise broadcast or fail with an unclear error.
    if torch_output.shape != ort_output.shape:
        raise ValueError(
            f"Output shape mismatch: torch {torch_output.shape} vs onnxruntime {ort_output.shape}"
        )

    # Compare in float32 so fp16 outputs do not lose precision in the diff.
    reference = torch_output.astype(np.float32)
    candidate = np.asarray(ort_output, dtype=np.float32)
    abs_diff = np.max(np.abs(reference - candidate))
    rel_diff = abs_diff / (np.maximum(1.0, np.max(np.abs(reference))))
    LOGGER.info("ONNX Runtime check done (abs=%.6f, rel=%.6f)", abs_diff, rel_diff)
| |
|
| |
|
def main() -> None:
    """Entry point: load the model, build sample inputs, export and validate.

    Without ``--split-control`` a single fused graph is exported. With it, the
    control branch and the transformer body are exported separately; the body's
    ``control_hints`` sample is obtained by running the control wrapper once in
    PyTorch. ONNX Runtime validation failures are logged as warnings only.
    """
    args = parse_args()
    device = torch.device("cpu")
    torch_dtype = torch.float16 if args.dtype == "fp16" else torch.float32

    torch.set_grad_enabled(False)
    transformer = load_transformer(args, torch_dtype, device)
    sample_inputs = build_dummy_inputs(args, transformer, torch_dtype, device)

    if not args.split_control:
        # ---- Fused export ----
        wrapper = TransformerOnnxWrapper(transformer, args.patch_size, args.f_patch_size)

        maybe_save_calibration_inputs("transformer", sample_inputs, args)

        transformer_model_path, _ = export_onnx(wrapper, sample_inputs, args.output, ["sample"], args)

        if not args.skip_ort_check:
            try:
                run_ort_validation(wrapper, sample_inputs, transformer_model_path, args.ort_provider)
            except Exception as exc:
                LOGGER.warning("ONNX Runtime validation failed: %s", exc)
        return

    # ---- Split export: control branch first, then the body ----
    control_wrapper = ControlNetWrapper(transformer, args.patch_size, args.f_patch_size)
    body_wrapper = TransformerBodyWrapper(transformer, args.patch_size, args.f_patch_size)

    control_keys = ("latent_model_input", "timestep", "prompt_embeds", "control_context")
    control_inputs = OrderedDict((key, sample_inputs[key]) for key in control_keys)

    maybe_save_calibration_inputs("controlnet", control_inputs, args)

    # Real hints are needed as the sample input of the body graph.
    with torch.inference_mode():
        control_hints_sample = control_wrapper(*control_inputs.values()).detach()

    control_model_path, _ = export_onnx(control_wrapper, control_inputs, args.control_output, ["hints"], args)

    body_inputs = OrderedDict(
        [
            ("latent_model_input", sample_inputs["latent_model_input"]),
            ("timestep", sample_inputs["timestep"]),
            ("prompt_embeds", sample_inputs["prompt_embeds"]),
            ("control_hints", control_hints_sample),
            ("control_context_scale", sample_inputs["control_context_scale"]),
        ]
    )

    maybe_save_calibration_inputs("transformer_body", body_inputs, args)

    body_model_path, _ = export_onnx(body_wrapper, body_inputs, args.body_output, ["sample"], args)

    if not args.skip_ort_check:
        checks = (
            (control_wrapper, control_inputs, control_model_path),
            (body_wrapper, body_inputs, body_model_path),
        )
        # One try block for both checks: a failure in the first skips the second.
        try:
            for module, feed, model_path in checks:
                run_ort_validation(module, feed, model_path, args.ort_provider)
        except Exception as exc:
            LOGGER.warning("ONNX Runtime validation failed: %s", exc)
| |
|
| |
|
if __name__ == "__main__":
    # The bare string below is an example command line kept for reference; as
    # an expression statement it has no runtime effect.
    """
    python scripts/z_image_fun/export_transformer_onnx.py \
    --split-control \
    --control-output onnx-models/z_image_controlnet.onnx \
    --body-output onnx-models/z_image_transformer_body.onnx \
    --save-calib-inputs \
    --height 512 \
    --width 512 \
    --sequence-length 128 \
    --latent-downsample-factor 8 \
    --skip-slim \
    --dtype fp32
    """
    main()
| |
|