#!/usr/bin/env python3
"""Export the Z-Image control transformer to ONNX for inference."""
import argparse
import logging
import os
import sys
from collections import OrderedDict
from typing import Any, Dict, List, Optional, OrderedDict as OrderedDictType, Tuple
import numpy as np
import torch
from omegaconf import OmegaConf
from loguru import logger
import onnx
from onnx import numpy_helper
import subprocess
# Repository root: two directories above this script. It is prepended to
# sys.path so the project-local ``videox_fun`` package can be imported when the
# script is run directly.
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)
# These imports must stay below the sys.path tweak above.
from videox_fun.models import ZImageControlTransformer2DModel  # noqa: E402
from videox_fun.models.z_image_transformer2d import pad_stack  # noqa: E402
LOGGER = logging.getLogger("export_transformer_onnx")
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s")
# Per-sample token counts (image patches and caption tokens) must be multiples
# of this value; _prepare_transformer_state asserts it.
SEQ_MULTI_OF = 32
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the command-line parser and parse the export options.

    Args:
        argv: Argument list to parse. ``None`` (the default) parses
            ``sys.argv[1:]``, exactly as before. Passing an explicit list lets
            other scripts and tests drive the exporter without touching
            ``sys.argv``.

    Returns:
        The parsed :class:`argparse.Namespace`.
    """
    parser = argparse.ArgumentParser(description="Export the Z-Image control transformer to ONNX")
    parser.add_argument("--config", default="config/z_image/z_image_control.yaml", help="Path to the YAML config used to build the transformer")
    parser.add_argument("--model-root", default="models/Diffusion_Transformer/Z-Image-Turbo/", help="Directory that stores the original diffusers weights")
    parser.add_argument("--checkpoint", default="models/Personalized_Model/Z-Image-Turbo-Fun-Controlnet-Union.safetensors", help="Optional fine-tuned checkpoint to load")
    parser.add_argument("--output", default="onnx-models/z_image_control_transformer.onnx", help="Target ONNX file path")
    parser.add_argument("--body-output", default="onnx-models/z_image_transformer_body.onnx", help="Path for the body-only ONNX when --split-control is enabled")
    parser.add_argument("--control-output", default="onnx-models/z_image_controlnet.onnx", help="Path for the control-only ONNX when --split-control is enabled")
    parser.add_argument("--height", type=int, default=864, help="Target image height used to derive latent resolution")
    parser.add_argument("--width", type=int, default=496, help="Target image width used to derive latent resolution")
    parser.add_argument("--batch-size", type=int, default=1, help="Batch size for the exported graph")
    parser.add_argument("--sequence-length", type=int, default=512, help="Prompt embedding sequence length (must be a multiple of 32)")
    parser.add_argument("--frames", type=int, default=1, help="Number of frames in the latent tensor")
    parser.add_argument("--latent-downsample-factor", type=int, default=8, help="Downsampling ratio between spatial image size and latent size")
    parser.add_argument("--latent-height", type=int, default=None, help="Override latent height (after downsampling)")
    parser.add_argument("--latent-width", type=int, default=None, help="Override latent width (after downsampling)")
    parser.add_argument("--dtype", choices=["fp16", "fp32"], default="fp16", help="Export precision")
    parser.add_argument("--control-scale", type=float, default=0.75, help="Default control context scale input")
    parser.add_argument("--patch-size", type=int, default=2, help="Spatial patch size used by the transformer")
    parser.add_argument("--f-patch-size", type=int, default=1, help="Frame patch size used by the transformer")
    parser.add_argument("--opset", type=int, default=17, help="ONNX opset version")
    parser.add_argument("--no-external-data", action="store_true", help="Disable external data format even if the model is larger than 2GB")
    parser.add_argument("--skip-ort-check", action="store_true", help="Skip running an ONNX Runtime correctness check")
    parser.add_argument("--ort-provider", default="CPUExecutionProvider", help="ONNX Runtime provider used during validation")
    parser.add_argument("--split-control", action="store_true", help="Export transformer body and ControlNet separately instead of a fused model")
    parser.add_argument("--save-calib-inputs", action="store_true", help="Dump ONNX input dictionaries as .npy for calibration")
    parser.add_argument("--calib-dir", default="onnx-calibration", help="Directory for storing calibration npy files")
    parser.add_argument("--dynamic-axes", action="store_true", help="Export ONNX with dynamic batch/seq/latent dims; default is static shape")
    parser.add_argument("--skip-slim", action="store_true", help="Skip onnxslim simplification for faster debug export")
    # parse_args(None) falls back to sys.argv[1:], preserving the old behaviour.
    return parser.parse_args(argv)
def run_onnxslim(input_file: str = "vae.onnx", output_file: str = "vae_slim.onnx", executable: str = "onnxslim") -> bool:
    """Run the ``onnxslim`` CLI to simplify/compress an ONNX model.

    The child's output is echoed live. stderr is merged into stdout so only one
    pipe has to be drained. Reading just stdout while stderr sits in a separate,
    unread pipe can deadlock once the child fills the stderr buffer.

    Args:
        input_file: ONNX model to simplify.
        output_file: Destination path for the simplified model.
        executable: Command used to launch onnxslim (a name on PATH or a full path).

    Returns:
        True if the command exited with status 0, otherwise False (including
        when the executable cannot be found or cannot be started).
    """
    cmd = [executable, input_file, output_file]
    print(f"执行命令: {' '.join(cmd)}")
    output_lines: List[str] = []
    try:
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            # Never crash on tool output that is not valid in the locale encoding.
            errors="replace",
            bufsize=1,
        ) as process:
            # Echo output in real time while keeping it for the failure summary.
            for line in process.stdout:
                print(line, end="")
                output_lines.append(line)
            returncode = process.wait()
    except FileNotFoundError:
        print("错误: 未找到 onnxslim 命令, 请确保已安装 onnxslim")
        # The onnxslim CLI is distributed in the `onnxslim` package, not `onnx-simplifier`.
        print("安装方法: pip install onnxslim")
        return False
    except Exception as e:
        print(f"执行命令时发生错误: {e}")
        return False
    if returncode != 0:
        # Everything was already echoed; repeat only the tail as the error summary.
        tail = "".join(output_lines[-20:])
        print(f"命令执行失败, 错误信息:\n{tail}")
        return False
    print("ONNX模型压缩完成!")
    return True
def _resolve_path(path: str) -> str:
return os.path.abspath(os.path.join(REPO_ROOT, path)) if not os.path.isabs(path) else path
def load_transformer(args: argparse.Namespace, torch_dtype: torch.dtype, device: torch.device) -> ZImageControlTransformer2DModel:
    """Build the control transformer from diffusers weights and apply an optional checkpoint.

    Args:
        args: Parsed CLI namespace; reads ``config``, ``model_root`` and
            ``checkpoint`` (an empty checkpoint string skips checkpoint loading).
        torch_dtype: dtype used for loading and casting the weights.
        device: Device the model is moved to.

    Returns:
        The transformer in eval mode on ``device`` with ``torch_dtype``.

    Raises:
        FileNotFoundError: If the config, the model root or a non-empty
            checkpoint path does not exist. All paths are checked before any
            weights are loaded.
    """
    config_path = _resolve_path(args.config)
    model_root = _resolve_path(args.model_root)
    checkpoint_path = _resolve_path(args.checkpoint) if args.checkpoint else None
    if not os.path.exists(config_path):
        raise FileNotFoundError(f"Config not found: {config_path}")
    if not os.path.isdir(model_root):
        raise FileNotFoundError(f"Model root not found: {model_root}")
    # Validate the checkpoint up front so a typo fails immediately instead of
    # after the (slow, memory-heavy) base model load.
    if checkpoint_path and not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    config = OmegaConf.load(config_path)
    transformer_kwargs = OmegaConf.to_container(config.get("transformer_additional_kwargs", {}), resolve=True)
    LOGGER.info("Loading transformer from %s", model_root)
    transformer = ZImageControlTransformer2DModel.from_pretrained(
        model_root,
        subfolder="transformer",
        low_cpu_mem_usage=True,
        torch_dtype=torch_dtype,
        transformer_additional_kwargs=transformer_kwargs,
    )
    transformer.eval()
    transformer.to(device=device, dtype=torch_dtype)

    if checkpoint_path:
        LOGGER.info("Loading checkpoint %s", checkpoint_path)
        if checkpoint_path.endswith(".safetensors"):
            from safetensors.torch import load_file  # type: ignore
            state_dict = load_file(checkpoint_path)
        else:
            # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
            state_dict = torch.load(checkpoint_path, map_location="cpu")
        # Some training scripts nest the weights under a "state_dict" key.
        state_dict = state_dict.get("state_dict", state_dict)
        # strict=False: the checkpoint may only cover a subset of the parameters.
        missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
        # Drop the checkpoint copy of the weights now that they are in the model.
        del state_dict
        LOGGER.info("Checkpoint loaded (missing=%d, unexpected=%d)", len(missing), len(unexpected))
        if unexpected:
            # Unexpected keys were silently ignored; surface a few to catch key-prefix mismatches.
            LOGGER.warning("Ignored unexpected checkpoint keys (first 5): %s", list(unexpected)[:5])
    return transformer
def _tensor_batch_to_list(batch_tensor: torch.Tensor) -> List[torch.Tensor]:
return [batch_tensor[i] for i in range(batch_tensor.shape[0])]
def _prepare_transformer_state(
    model: ZImageControlTransformer2DModel,
    latent_list: List[torch.Tensor],
    prompt_list: List[torch.Tensor],
    timestep: torch.Tensor,
    patch_size: int,
    f_patch_size: int,
) -> Dict[str, Any]:
    """Compute the shared state the split wrappers need before the main layers.

    It calls the model's submodules directly, in this order:
    timestep embedding, patchify + embed, image-token refinement
    (``noise_refiner``), caption refinement (``context_refiner``), then builds
    the unified image+caption sequence. ``TransformerBodyWrapper`` feeds that
    sequence to ``model.layers``; ``ControlNetWrapper`` feeds it to
    ``model.forward_control``.

    Args:
        model: Transformer whose submodules are invoked.
        latent_list: Per-sample latent tensors (one list entry per batch item).
        prompt_list: Per-sample prompt embeddings.
        timestep: Timestep tensor; cast to float32 and multiplied by ``model.t_scale``.
        patch_size: Spatial patch size; selects ``all_x_embedder["{p}-{f}"]``.
        f_patch_size: Frame patch size; selects ``all_x_embedder["{p}-{f}"]``.

    Returns:
        Dict with keys:
          - ``x`` / ``cap_feats``: refined image / caption tokens after pad_stack.
          - ``unified``: per-sample [image tokens, caption tokens] after pad_stack.
          - ``kwargs``: ``attn_mask`` / ``freqs_cis`` / ``adaln_input`` for the unified sequence.
          - ``adaln_input``: timestep embedding cast to the token dtype.
          - ``time_embed``: raw timestep embedding ``t``.
          - ``x_item_seqlens`` / ``cap_item_seqlens``: true per-sample token counts.
          - ``x_size``: as returned by ``patchify_and_embed`` (later passed to ``unpatchify``).
          - ``unified_attn_mask`` / ``unified_freqs_cis``: same tensors as inside ``kwargs``.
          - ``bsz``: batch size (``len(latent_list)``).

    Raises:
        AssertionError: If a per-sample token count is not a multiple of
            ``SEQ_MULTI_OF`` (plain ``assert``; skipped under ``python -O``).
    """
    bsz = len(latent_list)
    device = latent_list[0].device
    timestep = timestep.to(device=device, dtype=torch.float32)
    t = timestep * model.t_scale
    t = model.t_embedder(t)
    (
        x,
        cap_feats,
        x_size,
        x_pos_ids,
        cap_pos_ids,
        x_inner_pad_mask,
        cap_inner_pad_mask,
    ) = model.patchify_and_embed(latent_list, prompt_list, patch_size, f_patch_size)
    # Latent tokens refinement
    # x is a per-sample sequence of token tensors; concatenate, embed, then split back.
    x_item_seqlens = [len(_) for _ in x]
    assert all(_ % SEQ_MULTI_OF == 0 for _ in x_item_seqlens)
    x_max_item_seqlen = max(x_item_seqlens)
    x = torch.cat(x, dim=0)
    x = model.all_x_embedder[f"{patch_size}-{f_patch_size}"](x)
    adaln_input = t.type_as(x)
    mask = torch.cat(x_inner_pad_mask)
    # Overwrite positions where the mask is True with the model's pad token.
    # During ONNX export this is written as torch.where instead of a boolean
    # index assignment (presumably because that exports more cleanly).
    if torch.onnx.is_in_onnx_export():
        if model.x_pad_token.dim() == 1:
            x_pad_token_2d = model.x_pad_token.unsqueeze(0)
        else:
            x_pad_token_2d = model.x_pad_token
        mask_2d = mask.unsqueeze(1)
        x_pad_expanded = x_pad_token_2d.expand_as(x)
        x = torch.where(mask_2d, x_pad_expanded, x)
    else:
        x[mask] = model.x_pad_token
    x = list(x.split(x_item_seqlens, dim=0))
    x_freqs_cis = list(model.rope_embedder(torch.cat(x_pos_ids, dim=0)).split(x_item_seqlens, dim=0))
    # pad_stack presumably right-pads each item to the given length and stacks
    # into a batch tensor — TODO confirm against videox_fun.
    x = pad_stack(x, x_max_item_seqlen, pad_value=0.0)
    x_freqs_cis = pad_stack(x_freqs_cis, x_max_item_seqlen, pad_value=0.0)
    # Attention mask: True for each sample's first seq_len (real) positions.
    x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(x_item_seqlens):
        x_attn_mask[i, :seq_len] = 1
    for layer in model.noise_refiner:
        x = layer(x, x_attn_mask, x_freqs_cis, adaln_input)
    # Caption refinement (same pattern as above; context_refiner layers take no adaln input)
    cap_item_seqlens = [len(_) for _ in cap_feats]
    assert all(_ % SEQ_MULTI_OF == 0 for _ in cap_item_seqlens)
    cap_max_item_seqlen = max(cap_item_seqlens)
    cap_feats = torch.cat(cap_feats, dim=0)
    cap_feats = model.cap_embedder(cap_feats)
    cap_mask = torch.cat(cap_inner_pad_mask)
    if torch.onnx.is_in_onnx_export():
        if model.cap_pad_token.dim() == 1:
            cap_pad_token_2d = model.cap_pad_token.unsqueeze(0)
        else:
            cap_pad_token_2d = model.cap_pad_token
        mask_2d = cap_mask.unsqueeze(1)
        cap_pad_expanded = cap_pad_token_2d.expand_as(cap_feats)
        cap_feats = torch.where(mask_2d, cap_pad_expanded, cap_feats)
    else:
        cap_feats[cap_mask] = model.cap_pad_token
    cap_feats = list(cap_feats.split(cap_item_seqlens, dim=0))
    cap_freqs_cis = list(model.rope_embedder(torch.cat(cap_pos_ids, dim=0)).split(cap_item_seqlens, dim=0))
    cap_feats = pad_stack(cap_feats, cap_max_item_seqlen, pad_value=0.0)
    cap_freqs_cis = pad_stack(cap_freqs_cis, cap_max_item_seqlen, pad_value=0.0)
    cap_attn_mask = torch.zeros((bsz, cap_max_item_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(cap_item_seqlens):
        cap_attn_mask[i, :seq_len] = 1
    for layer in model.context_refiner:
        cap_feats = layer(cap_feats, cap_attn_mask, cap_freqs_cis)
    # Context parallel handling: with sp_world_size > 1, keep only this rank's
    # chunk of the image tokens (and their rotary frequencies) along dim 1.
    if model.sp_world_size > 1:
        x = torch.chunk(x, model.sp_world_size, dim=1)[model.sp_world_rank]
        x_item_seqlens = [len(_) for _ in x]
        x_max_item_seqlen = max(x_item_seqlens)
        x_attn_mask = torch.zeros((bsz, x_max_item_seqlen), dtype=torch.bool, device=device)
        for i, seq_len in enumerate(x_item_seqlens):
            x_attn_mask[i, :seq_len] = 1
        if x_freqs_cis is not None:
            x_freqs_cis = torch.chunk(x_freqs_cis, model.sp_world_size, dim=1)[model.sp_world_rank]
    # Unified sequence: for each sample, its real image tokens followed by its
    # real caption tokens (padding stripped), then re-padded into a batch.
    unified = []
    unified_freqs_cis = []
    for i in range(bsz):
        x_len = x_item_seqlens[i]
        cap_len = cap_item_seqlens[i]
        unified.append(torch.cat([x[i][:x_len], cap_feats[i][:cap_len]]))
        unified_freqs_cis.append(torch.cat([x_freqs_cis[i][:x_len], cap_freqs_cis[i][:cap_len]]))
    unified_item_seqlens = [a + b for a, b in zip(cap_item_seqlens, x_item_seqlens)]
    unified_max_item_seqlen = max(unified_item_seqlens)
    unified = pad_stack(unified, unified_max_item_seqlen, pad_value=0.0)
    unified_freqs_cis = pad_stack(unified_freqs_cis, unified_max_item_seqlen, pad_value=0.0)
    unified_attn_mask = torch.zeros((bsz, unified_max_item_seqlen), dtype=torch.bool, device=device)
    for i, seq_len in enumerate(unified_item_seqlens):
        unified_attn_mask[i, :seq_len] = 1
    kwargs = dict(attn_mask=unified_attn_mask, freqs_cis=unified_freqs_cis, adaln_input=adaln_input)
    return dict(
        x=x,
        cap_feats=cap_feats,
        unified=unified,
        kwargs=kwargs,
        adaln_input=adaln_input,
        time_embed=t,
        x_item_seqlens=x_item_seqlens,
        cap_item_seqlens=cap_item_seqlens,
        x_size=x_size,
        unified_attn_mask=unified_attn_mask,
        unified_freqs_cis=unified_freqs_cis,
        bsz=bsz,
    )
class TransformerOnnxWrapper(torch.nn.Module):
    """Fused export wrapper: batched tensors in, model output tensor out.

    The wrapped model is called with Python lists of per-sample latents and
    prompt embeddings. This module unbinds the batched ONNX inputs into those
    lists and returns only the first element of the model's two-element output
    (exported as ``sample``).
    """

    def __init__(self, model: ZImageControlTransformer2DModel, patch_size: int, f_patch_size: int):
        super().__init__()
        self.model = model
        self.patch_size, self.f_patch_size = patch_size, f_patch_size

    def forward(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,
        prompt_embeds: torch.Tensor,
        control_context: torch.Tensor,
        control_context_scale: torch.Tensor,
    ) -> torch.Tensor:
        reference = latent_model_input
        # Put the control scale on the latents' device and dtype before use.
        scale = control_context_scale.to(device=reference.device, dtype=reference.dtype)
        per_sample_latents = list(reference.unbind(dim=0))
        per_sample_prompts = list(prompt_embeds.unbind(dim=0))
        model_out = self.model(
            per_sample_latents,
            timestep,
            per_sample_prompts,
            patch_size=self.patch_size,
            f_patch_size=self.f_patch_size,
            control_context=control_context,
            control_context_scale=scale,
        )
        sample, _ignored = model_out
        return sample
class ControlNetWrapper(torch.nn.Module):
    """Export wrapper for the control branch alone.

    Runs the shared pre-layer pipeline (``_prepare_transformer_state``) and
    then ``model.forward_control``. It returns the control hints stacked along
    a new leading dimension. In the split export, that tensor is what
    ``TransformerBodyWrapper`` receives as ``control_hints``.
    """

    def __init__(self, model: ZImageControlTransformer2DModel, patch_size: int, f_patch_size: int):
        super().__init__()
        self.model = model
        self.patch_size = patch_size
        self.f_patch_size = f_patch_size

    def forward(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,
        prompt_embeds: torch.Tensor,
        control_context: torch.Tensor,
    ) -> torch.Tensor:
        patch, f_patch = self.patch_size, self.f_patch_size
        state = _prepare_transformer_state(
            self.model,
            _tensor_batch_to_list(latent_model_input),
            _tensor_batch_to_list(prompt_embeds),
            timestep,
            patch,
            f_patch,
        )
        per_sample_controls = _tensor_batch_to_list(control_context)
        hint_list = self.model.forward_control(
            state["unified"],
            state["cap_feats"],
            per_sample_controls,
            state["kwargs"],
            t=state["time_embed"],
            patch_size=patch,
            f_patch_size=f_patch,
        )
        return torch.stack(hint_list, dim=0)
class TransformerBodyWrapper(torch.nn.Module):
    """Export wrapper for the transformer body that consumes precomputed hints.

    Inputs mirror the fused model, except that ``control_hints`` (the stacked
    output of the exported ControlNet graph) replaces ``control_context``. It
    runs the shared pre-layer pipeline, every ``model.layers`` block, the final
    layer and ``unpatchify``, and returns the stacked per-sample outputs.
    """

    def __init__(self, model: ZImageControlTransformer2DModel, patch_size: int, f_patch_size: int):
        super().__init__()
        self.model = model
        self.patch_size = patch_size
        self.f_patch_size = f_patch_size

    def forward(
        self,
        latent_model_input: torch.Tensor,
        timestep: torch.Tensor,
        prompt_embeds: torch.Tensor,
        control_hints: torch.Tensor,
        control_context_scale: torch.Tensor,
    ) -> torch.Tensor:
        net = self.model
        patch_key = f"{self.patch_size}-{self.f_patch_size}"
        scale = control_context_scale.to(device=latent_model_input.device, dtype=latent_model_input.dtype)
        state = _prepare_transformer_state(
            net,
            _tensor_batch_to_list(latent_model_input),
            _tensor_batch_to_list(prompt_embeds),
            timestep,
            self.patch_size,
            self.f_patch_size,
        )
        # These keyword arguments are identical for every layer, so build them once.
        shared_layer_kwargs = {
            "attn_mask": state["unified_attn_mask"],
            "freqs_cis": state["unified_freqs_cis"],
            "adaln_input": state["adaln_input"],
            "hints": list(torch.unbind(control_hints, dim=0)),
            "context_scale": scale,
        }
        hidden = state["unified"]
        for block in net.layers:
            hidden = block(hidden, **shared_layer_kwargs)

        if net.sp_world_size > 1:
            # Keep only the image-token prefix of each sample, then gather across ranks.
            image_lens = state["x_item_seqlens"]
            hidden = torch.stack([hidden[idx, :image_lens[idx]] for idx in range(state["bsz"])])
            hidden = net.all_gather(hidden, dim=1)

        hidden = net.all_final_layer[patch_key](hidden, state["adaln_input"])
        samples = net.unpatchify(list(hidden.unbind(dim=0)), state["x_size"], self.patch_size, self.f_patch_size)
        return torch.stack(samples)
def _validate_sequence_length(seq_len: int) -> None:
if seq_len % 32 != 0:
raise ValueError("sequence_length must be a multiple of 32 to satisfy transformer padding rules")
def _compute_latent_dims(args: argparse.Namespace) -> Dict[str, int]:
if args.latent_height is not None and args.latent_width is not None:
latent_h = args.latent_height
latent_w = args.latent_width
else:
if args.height % args.latent_downsample_factor != 0 or args.width % args.latent_downsample_factor != 0:
raise ValueError("height and width must be divisible by latent_downsample_factor")
latent_h = args.height // args.latent_downsample_factor
latent_w = args.width // args.latent_downsample_factor
if latent_h % args.patch_size != 0 or latent_w % args.patch_size != 0:
raise ValueError("latent dimensions must be divisible by patch_size")
if args.frames % args.f_patch_size != 0:
raise ValueError("frames must be divisible by f_patch_size")
return {"latent_h": latent_h, "latent_w": latent_w}
def build_dummy_inputs(
    args: argparse.Namespace,
    model: ZImageControlTransformer2DModel,
    torch_dtype: torch.dtype,
    device: torch.device,
) -> OrderedDictType[str, torch.Tensor]:
    """Create random example tensors used to trace the ONNX export.

    The key order matters: ``export_onnx`` uses it as the ONNX input order.

    Returns:
        OrderedDict with
          - ``latent_model_input``: (batch, in_channels, frames, latent_h, latent_w), ``torch_dtype``
          - ``timestep``: (batch,) float32, evenly spaced over [0, 1]
          - ``prompt_embeds``: (batch, sequence_length, cap_feat_dim), ``torch_dtype``
          - ``control_context``: random tensor with the latents' shape/dtype
          - ``control_context_scale``: (1,) float32 filled with ``--control-scale``

    Raises:
        ValueError: Propagated from the sequence-length / latent-size validation.
    """
    _validate_sequence_length(args.sequence_length)
    dims = _compute_latent_dims(args)
    batch = args.batch_size
    config = model.config
    tensor_opts = {"dtype": torch_dtype, "device": device}

    # RNG draw order (latent, prompts, control) is kept fixed for reproducibility.
    latent = torch.randn(
        (batch, config.in_channels, args.frames, dims["latent_h"], dims["latent_w"]),
        **tensor_opts,
    )
    timestep = torch.linspace(0.0, 1.0, steps=batch, dtype=torch.float32, device=device)
    prompts = torch.randn((batch, args.sequence_length, config.cap_feat_dim), **tensor_opts)
    # NOTE(review): assumes the control context has the latents' channel count — confirm against the model config.
    control = torch.randn_like(latent)
    control_scale = torch.full((1,), args.control_scale, dtype=torch.float32, device=device)

    return OrderedDict(
        [
            ("latent_model_input", latent),
            ("timestep", timestep),
            ("prompt_embeds", prompts),
            ("control_context", control),
            ("control_context_scale", control_scale),
        ]
    )
def maybe_save_calibration_inputs(tag: str, inputs: OrderedDictType[str, torch.Tensor], args: argparse.Namespace) -> Optional[str]:
    """Optionally dump an ONNX input mapping to ``<calib_dir>/<tag>_inputs.npy``.

    Does nothing and returns ``None`` unless ``args.save_calib_inputs`` is
    truthy (a missing attribute counts as False). Otherwise it converts every
    tensor to a CPU numpy array and saves the whole dict with ``np.save``.
    That wraps the dict in a 0-d object array, so read it back with
    ``np.load(path, allow_pickle=True).item()``.

    Returns:
        The written file path, or ``None`` when saving is disabled.
    """
    enabled = getattr(args, "save_calib_inputs", False)
    if not enabled:
        return None
    target_dir = _resolve_path(args.calib_dir)
    os.makedirs(target_dir, exist_ok=True)
    arrays = {key: value.detach().cpu().numpy() for key, value in inputs.items()}
    target_file = os.path.join(target_dir, f"{tag}_inputs.npy")
    np.save(target_file, arrays, allow_pickle=True)
    LOGGER.info("Saved calibration inputs (%s) to %s", tag, target_file)
    return target_file
def dump_initializer_parameters(model_path: str) -> str:
    """Write every graph initializer of an ONNX model into a standalone ``.npz``.

    Loads the model together with its external data. It stores each
    initializer under its ONNX name in ``<model_path>.params.npz``. If names
    repeat, the last one wins.

    Returns:
        Path of the written ``.npz`` file.
    """
    proto = onnx.load(model_path, load_external_data=True)
    arrays = {init.name: numpy_helper.to_array(init) for init in proto.graph.initializer}
    out_path = model_path + ".params.npz"
    np.savez(out_path, **arrays)
    LOGGER.info("Saved %d parameters to %s", len(arrays), out_path)
    return out_path
def _save_onnx_model(model_proto: "onnx.ModelProto", path: str, use_external: bool) -> Optional[str]:
    """Save *model_proto* to *path*, optionally moving weights to ``<path>.data``.

    Returns the sidecar weight path, or ``None`` when everything is stored inline.
    """
    if not use_external:
        # Single-file protobuf; onnx refuses models larger than 2GB here.
        onnx.save(model_proto, path)
        return None
    data_name = os.path.basename(path) + ".data"
    data_path = os.path.join(os.path.dirname(path), data_name)
    # onnx appends tensors to an existing external-data file rather than
    # truncating it, so drop leftovers from a previous export first.
    if os.path.exists(data_path):
        os.remove(data_path)
    onnx.save(
        model_proto,
        path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        # Without an explicit location, onnx names the sidecar with a random UUID.
        location=data_name,
    )
    return data_path


def export_onnx(
    wrapper: torch.nn.Module,
    sample_inputs: OrderedDictType[str, torch.Tensor],
    output_path: str,
    output_names: List[str],
    args: argparse.Namespace,
) -> Tuple[str, str]:
    """Trace *wrapper* to ONNX, re-save it, and optionally simplify it with onnxslim.

    Steps:
      1. ``torch.onnx.export`` to ``output_path`` (resolved against the repo root).
      2. Reload and re-save as ``<stem>_simp.onnx``. Weights go to
         ``<stem>_simp.onnx.data``, or stay inline when ``--no-external-data`` is set.
      3. Unless ``--skip-slim`` is set, run onnxslim to produce ``<stem>_simp_slim.onnx``.

    Args:
        wrapper: Module to trace.
        sample_inputs: Example inputs. Their key order defines the ONNX input names/order.
        output_path: Target path for the raw export.
        output_names: ONNX output names.
        args: CLI namespace; reads ``opset``, ``dynamic_axes``, ``no_external_data``
            and ``skip_slim``.

    Returns:
        ``(final_onnx_path, param_path)``. ``param_path`` is currently always
        ``""`` because parameter dumping is disabled.

    Raises:
        RuntimeError: If onnxslim fails.
    """
    export_path = _resolve_path(output_path)
    export_dir = os.path.dirname(export_path)
    if export_dir:
        os.makedirs(export_dir, exist_ok=True)
    input_names = list(sample_inputs.keys())
    io_names = set(input_names) | set(output_names)
    use_external = not args.no_external_data
    wrapper.eval()

    dynamic_axes = None
    if args.dynamic_axes:
        all_axes = {
            "latent_model_input": {0: "batch", 2: "frames", 3: "latent_h", 4: "latent_w"},
            "prompt_embeds": {0: "batch", 1: "seq_len"},
            "timestep": {0: "batch"},
            "control_context": {0: "batch", 2: "frames", 3: "latent_h", 4: "latent_w"},
            "control_hints": {0: "batch", 2: "frames", 3: "latent_h", 4: "latent_w"},
            "control_context_scale": {0: "scale_batch"},
            "sample": {0: "batch", 2: "frames", 3: "latent_h", 4: "latent_w"},
            "hints": {0: "batch", 2: "frames", 3: "latent_h", 4: "latent_w"},
        }
        # Keep only entries for tensors that exist in this particular graph.
        dynamic_axes = {name: axes for name, axes in all_axes.items() if name in io_names}

    LOGGER.info("Exporting ONNX to %s", export_path)
    with torch.inference_mode():
        torch.onnx.export(
            wrapper,
            args=tuple(sample_inputs[name] for name in input_names),
            f=export_path,
            input_names=input_names,
            output_names=output_names,
            opset_version=args.opset,
            do_constant_folding=True,
            export_params=True,
            dynamic_axes=dynamic_axes,
        )
    LOGGER.info("Raw ONNX export finished")

    trans_onnx = onnx.load(export_path)
    simp_onnx_path = os.path.splitext(export_path)[0] + "_simp.onnx"
    external_weight_file = _save_onnx_model(trans_onnx, simp_onnx_path, use_external)
    # Release the in-memory copy before onnxslim loads the model again.
    del trans_onnx
    if external_weight_file:
        LOGGER.info("Saved external-data ONNX to %s (weights -> %s)", simp_onnx_path, external_weight_file)
    else:
        LOGGER.info("Saved single-file ONNX to %s", simp_onnx_path)

    if args.skip_slim:
        LOGGER.info("Skip onnxslim as requested, using simplified external-data ONNX: %s", simp_onnx_path)
        final_onnx = simp_onnx_path
    else:
        slim_onnx_path = os.path.splitext(simp_onnx_path)[0] + "_slim.onnx"
        LOGGER.info("Transformer ONNX model exported, start to simplify via onnxslim")
        if not run_onnxslim(simp_onnx_path, slim_onnx_path):
            raise RuntimeError("onnxslim simplification failed, please check logs")
        final_onnx = slim_onnx_path
    LOGGER.info("Transformer ONNX model exported successfully: %s", final_onnx)
    # Parameter dumping (dump_initializer_parameters) is disabled: not needed for now.
    param_path = ""
    return final_onnx, param_path
def run_ort_validation(
    wrapper: torch.nn.Module,
    sample_inputs: OrderedDictType[str, torch.Tensor],
    onnx_path: str,
    provider: str,
) -> None:
    """Compare the first ONNX Runtime output against the PyTorch wrapper.

    Best effort: if onnxruntime is not installed, the check is skipped with a
    warning. Numeric differences are only logged, never raised on.

    Args:
        wrapper: The PyTorch module that was exported.
        sample_inputs: Inputs fed to both PyTorch and ONNX Runtime.
        onnx_path: ONNX model to load.
        provider: ONNX Runtime execution provider name.

    Raises:
        ValueError: If the ONNX and PyTorch outputs have different shapes
            (otherwise numpy broadcasting would yield a meaningless diff).
    """
    try:
        import onnxruntime as ort
    except ImportError:  # pragma: no cover
        LOGGER.warning("onnxruntime not installed, skip validation")
        return
    wrapper.eval()
    with torch.inference_mode():
        torch_output = wrapper(*sample_inputs.values()).detach().cpu().numpy()
    sess_options = ort.SessionOptions()
    session = ort.InferenceSession(onnx_path, sess_options=sess_options, providers=[provider])
    # The exporter or onnxslim may drop graph inputs that ended up unused;
    # feeding an unknown input name makes session.run fail.
    session_input_names = {node.name for node in session.get_inputs()}
    ort_inputs = {
        name: tensor.detach().cpu().numpy()
        for name, tensor in sample_inputs.items()
        if name in session_input_names
    }
    ort_output = np.asarray(session.run(None, ort_inputs)[0])
    if ort_output.shape != torch_output.shape:
        raise ValueError(f"ONNX output shape {ort_output.shape} != PyTorch output shape {torch_output.shape}")
    # Compare in float64: differencing two fp16 arrays in fp16 loses precision.
    torch_ref = torch_output.astype(np.float64)
    diff = np.abs(torch_ref - ort_output.astype(np.float64))
    if diff.size == 0:
        abs_diff = rel_diff = 0.0
    else:
        abs_diff = float(diff.max())
        rel_diff = abs_diff / max(1.0, float(np.abs(torch_ref).max()))
    LOGGER.info("ONNX Runtime check done (abs=%.6f, rel=%.6f)", abs_diff, rel_diff)
def main() -> None:
    """CLI entry point: load the transformer and export fused or split ONNX models.

    With ``--split-control``, the ControlNet branch and the transformer body
    are exported as two graphs. The body's example ``control_hints`` input is
    produced by running the ControlNet wrapper once. Otherwise a single fused
    graph is exported. ONNX Runtime validation failures are logged as warnings,
    not raised.
    """
    args = parse_args()
    torch.set_grad_enabled(False)
    device = torch.device("cpu")
    torch_dtype = {"fp16": torch.float16}.get(args.dtype, torch.float32)
    transformer = load_transformer(args, torch_dtype, device)
    sample_inputs = build_dummy_inputs(args, transformer, torch_dtype, device)
    patch_args = (args.patch_size, args.f_patch_size)

    if not args.split_control:
        fused_wrapper = TransformerOnnxWrapper(transformer, *patch_args)
        maybe_save_calibration_inputs("transformer", sample_inputs, args)
        fused_path, _ = export_onnx(fused_wrapper, sample_inputs, args.output, ["sample"], args)
        if not args.skip_ort_check:
            try:
                run_ort_validation(fused_wrapper, sample_inputs, fused_path, args.ort_provider)
            except Exception as exc:  # pragma: no cover
                LOGGER.warning("ONNX Runtime validation failed: %s", exc)
        return

    control_wrapper = ControlNetWrapper(transformer, *patch_args)
    body_wrapper = TransformerBodyWrapper(transformer, *patch_args)

    control_keys = ("latent_model_input", "timestep", "prompt_embeds", "control_context")
    control_inputs = OrderedDict((key, sample_inputs[key]) for key in control_keys)
    maybe_save_calibration_inputs("controlnet", control_inputs, args)
    # Run the control branch once to obtain realistic example hints for the body export.
    with torch.inference_mode():
        control_hints_sample = control_wrapper(*control_inputs.values()).detach()
    control_path, _ = export_onnx(control_wrapper, control_inputs, args.control_output, ["hints"], args)

    body_inputs = OrderedDict(
        [
            ("latent_model_input", sample_inputs["latent_model_input"]),
            ("timestep", sample_inputs["timestep"]),
            ("prompt_embeds", sample_inputs["prompt_embeds"]),
            ("control_hints", control_hints_sample),
            ("control_context_scale", sample_inputs["control_context_scale"]),
        ]
    )
    maybe_save_calibration_inputs("transformer_body", body_inputs, args)
    body_path, _ = export_onnx(body_wrapper, body_inputs, args.body_output, ["sample"], args)

    if args.skip_ort_check:
        return
    try:
        run_ort_validation(control_wrapper, control_inputs, control_path, args.ort_provider)
        run_ort_validation(body_wrapper, body_inputs, body_path, args.ort_provider)
    except Exception as exc:  # pragma: no cover
        LOGGER.warning("ONNX Runtime validation failed: %s", exc)
if __name__ == "__main__":
    # Example: split export at a static 512x512 resolution in fp32, without onnxslim.
    #
    #   python scripts/z_image_fun/export_transformer_onnx.py \
    #       --split-control \
    #       --control-output onnx-models/z_image_controlnet.onnx \
    #       --body-output onnx-models/z_image_transformer_body.onnx \
    #       --save-calib-inputs \
    #       --height 512 \
    #       --width 512 \
    #       --sequence-length 128 \
    #       --latent-downsample-factor 8 \
    #       --skip-slim \
    #       --dtype fp32
    main()