#!/usr/bin/env python3
"""Export the VAE encoder and decoder as ONNX models.

Example::

    python scripts/z_image_fun/export_vae_onnx.py \\
        --model-root models/Diffusion_Transformer/Z-Image-Turbo/ \\
        --height 512 --width 512 \\
        --encoder-output onnx-models/vae_encoder.onnx \\
        --decoder-output onnx-models/vae_decoder.onnx \\
        --dtype fp32 \\
        --save-calib-inputs \\
        --skip-ort-check
"""
import argparse
import logging
import os
import sys
from collections import OrderedDict, deque
from typing import Any, Dict, Optional

import numpy as np
import torch
from loguru import logger
import onnx
from onnx import numpy_helper
import subprocess

# Repository root (two directories above this script); relative CLI paths are
# resolved against it, and it is put on sys.path so `videox_fun` is importable.
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

from videox_fun.models import AutoencoderKL  # noqa: E402

LOGGER = logging.getLogger("export_vae_onnx")
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s")


def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the VAE ONNX export."""
    parser = argparse.ArgumentParser(description="Export VAE encoder/decoder to ONNX")
    parser.add_argument("--model-root", default="models/Diffusion_Transformer/Z-Image-Turbo/", help="Diffusers 权重所在目录")
    parser.add_argument("--checkpoint", default=None, help="可选的 VAE finetune checkpoint")
    parser.add_argument("--encoder-output", default="onnx-models/vae_encoder.onnx", help="VAE Encoder ONNX 路径")
    parser.add_argument("--decoder-output", default="onnx-models/vae_decoder.onnx", help="VAE Decoder ONNX 路径")
    parser.add_argument("--height", type=int, default=864, help="导出时的图片高度")
    parser.add_argument("--width", type=int, default=496, help="导出时的图片宽度")
    parser.add_argument("--latent-downsample-factor", type=int, default=8, help="VAE 下采样倍数")
    parser.add_argument("--batch-size", type=int, default=1, help="导出 batch 大小")
    parser.add_argument("--dtype", choices=["fp16", "fp32"], default="fp16", help="导出精度")
    parser.add_argument("--opset", type=int, default=17, help="ONNX opset 版本")
    parser.add_argument("--dynamic-axes", action="store_true", help="是否导出动态维度")
    parser.add_argument("--skip-ort-check", action="store_true", help="跳过 onnxruntime 结果校验")
    parser.add_argument("--ort-provider", default="CPUExecutionProvider", help="onnxruntime provider")
    parser.add_argument("--skip-slim", action="store_true", help="跳过 onnxslim")
    parser.add_argument("--no-external-data", action="store_true", help="禁用外部数据格式保存")
    parser.add_argument("--save-calib-inputs", action="store_true", help="保存校准输入 npy")
    parser.add_argument("--calib-dir", default="onnx-calibration", help="校准输入保存目录")
    return parser.parse_args()


def run_onnxslim(input_file: str, output_file: str) -> bool:
    """Run the ``onnxslim`` CLI to simplify *input_file* into *output_file*.

    The child's stdout and stderr are merged and streamed to our stdout line by
    line. Best-effort: never raises; returns True on success and False on a
    non-zero exit code, a missing executable, or any other launch/read error.
    """
    cmd = ["onnxslim", input_file, output_file]
    print(f"执行命令: {' '.join(cmd)}")
    # Keep only the last lines of output for the failure message.
    tail: deque = deque(maxlen=50)
    try:
        # stderr is merged into stdout: draining only stdout while stderr is a
        # separate PIPE can deadlock once the child fills the stderr buffer
        # (progress bars are commonly written to stderr).
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,
            errors="replace",
            bufsize=1,
        ) as process:
            for line in process.stdout:
                print(line, end="")
                tail.append(line)
            returncode = process.wait()
    except FileNotFoundError:
        print("错误: 未找到 onnxslim 命令, 请确保已安装 onnxslim (pip install onnxslim)")
        return False
    except Exception as exc:
        print(f"执行命令时发生错误: {exc}")
        return False

    if returncode != 0:
        print(f"命令执行失败, 错误信息:\n{''.join(tail)}")
        return False
    print("ONNX模型压缩完成!")
    return True


def _resolve_path(path: str) -> str:
    """Return *path* unchanged if absolute, otherwise resolved against REPO_ROOT."""
    if os.path.isabs(path):
        return path
    return os.path.abspath(os.path.join(REPO_ROOT, path))


def _check_image_dims(height: int, width: int, factor: int) -> None:
    """Raise ValueError unless *factor* is positive and divides both image sides."""
    if factor <= 0:
        # Without this guard the modulo below would raise ZeroDivisionError.
        raise ValueError("latent_downsample_factor 必须为正整数")
    if height % factor != 0 or width % factor != 0:
        raise ValueError("height 和 width 需要能被 latent_downsample_factor 整除")


def _compute_latent_dims(height: int, width: int, factor: int) -> Dict[str, int]:
    """Return the latent spatial size ``{"latent_h", "latent_w"}`` for an image size."""
    _check_image_dims(height, width, factor)
    return {"latent_h": height // factor, "latent_w": width // factor}


def load_vae(args: argparse.Namespace, torch_dtype: torch.dtype, device: torch.device) -> AutoencoderKL:
    """Load the VAE from ``<model_root>/vae`` and optionally apply a checkpoint.

    The model is moved to *device*/*torch_dtype* and put in eval mode. If
    ``args.checkpoint`` is set, its weights are loaded non-strictly; the
    counts of missing and unexpected keys are logged.

    Raises:
        FileNotFoundError: if the model root or the checkpoint does not exist.
    """
    model_root = _resolve_path(args.model_root)
    checkpoint_path = _resolve_path(args.checkpoint) if args.checkpoint else None
    if not os.path.isdir(model_root):
        raise FileNotFoundError(f"Model root not found: {model_root}")

    LOGGER.info("Loading VAE from %s", model_root)
    vae = AutoencoderKL.from_pretrained(
        model_root,
        subfolder="vae",
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    )
    vae.to(device=device, dtype=torch_dtype)
    vae.eval()

    if checkpoint_path:
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
        LOGGER.info("Loading checkpoint %s", checkpoint_path)
        if checkpoint_path.endswith(".safetensors"):
            from safetensors.torch import load_file  # type: ignore

            state_dict = load_file(checkpoint_path)
        else:
            # NOTE(review): torch.load unpickles arbitrary objects; only use
            # checkpoints from trusted sources.
            state_dict = torch.load(checkpoint_path, map_location="cpu")
            # Training checkpoints often nest weights under "state_dict".
            state_dict = state_dict.get("state_dict", state_dict)
        missing, unexpected = vae.load_state_dict(state_dict, strict=False)
        LOGGER.info("Checkpoint loaded (missing=%d, unexpected=%d)", len(missing), len(unexpected))
    return vae


def build_dummy_inputs(args: argparse.Namespace, vae: AutoencoderKL, torch_dtype: torch.dtype, device: torch.device) -> OrderedDict:
    """Create random example tensors used for tracing.

    Returns an OrderedDict with:
      * ``pixel_values``: (batch, 3, height, width)
      * ``latent``: (batch, vae.config.latent_channels, height // factor, width // factor)

    Note: ``main`` only consumes ``pixel_values``; the decoder is traced with the
    encoder's real output instead of ``latent``.
    """
    dims = _compute_latent_dims(args.height, args.width, args.latent_downsample_factor)
    pixel_values = torch.randn(
        args.batch_size,
        3,
        args.height,
        args.width,
        dtype=torch_dtype,
        device=device,
    )
    latents = torch.randn(
        args.batch_size,
        vae.config.latent_channels,
        dims["latent_h"],
        dims["latent_w"],
        dtype=torch_dtype,
        device=device,
    )
    return OrderedDict(pixel_values=pixel_values, latent=latents)


class VAEEncoderWrapper(torch.nn.Module):
    """Expose ``vae.encode`` as a tensor-in/tensor-out module for ONNX export."""

    def __init__(self, model: AutoencoderKL):
        super().__init__()
        self.model = model

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode images and return the mode of the posterior (deterministic)."""
        # Element 0 of the encode output is assumed to be the latent
        # distribution object exposing .mode() — verify against AutoencoderKL.
        latent_dist = self.model.encode(pixel_values)[0]
        return latent_dist.mode()


class VAEDecoderWrapper(torch.nn.Module):
    """Expose ``vae.decode`` as a tensor-in/tensor-out module for ONNX export."""

    def __init__(self, model: AutoencoderKL):
        super().__init__()
        self.model = model

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode latents into images (first element of the tuple output)."""
        image = self.model.decode(latents, return_dict=False)[0]
        return image


def maybe_save_calibration_inputs(tag: str, inputs: OrderedDict, args: argparse.Namespace) -> Optional[str]:
    """Save *inputs* as a pickled dict in ``<calib_dir>/<tag>_inputs.npy``.

    Does nothing and returns None unless ``args.save_calib_inputs`` is set.
    The file holds a pickled dict, so read it back with
    ``np.load(path, allow_pickle=True).item()``.

    Returns:
        The written file path, or None when saving is disabled.
    """
    if not getattr(args, "save_calib_inputs", False):
        return None
    output_dir = _resolve_path(args.calib_dir)
    os.makedirs(output_dir, exist_ok=True)
    numpy_dict = {name: tensor.detach().cpu().numpy() for name, tensor in inputs.items()}
    file_path = os.path.join(output_dir, f"{tag}_inputs.npy")
    np.save(file_path, numpy_dict, allow_pickle=True)
    LOGGER.info("Saved calibration inputs (%s) to %s", tag, file_path)
    return file_path


def dump_initializer_parameters(model_path: str) -> str:
    """Dump every graph initializer of an ONNX model to ``<model_path>.params.npz``.

    Returns:
        The path of the written ``.npz`` file.
    """
    model_proto = onnx.load(model_path, load_external_data=True)
    param_dict: Dict[str, Any] = {}
    for initializer in model_proto.graph.initializer:
        param_dict[initializer.name] = numpy_helper.to_array(initializer)
    param_path = f"{model_path}.params.npz"
    np.savez(param_path, **param_dict)
    LOGGER.info("Saved %d parameters to %s", len(param_dict), param_path)
    return param_path


def export_onnx(
    wrapper: torch.nn.Module,
    sample_inputs: OrderedDict,
    output_path: str,
    output_names: list,
    args: argparse.Namespace,
) -> str:
    """Trace *wrapper* to ONNX, re-save it, and optionally run onnxslim.

    Steps:
      1. ``torch.onnx.export`` to *output_path* (resolved against REPO_ROOT).
      2. Re-save as ``<stem>_simp.onnx``; tensors go to a single external-data
         file unless ``--no-external-data`` was given.
      3. Unless ``--skip-slim``, run onnxslim to produce ``<stem>_simp_slim.onnx``.

    Input names are the keys of *sample_inputs*, in order.

    Returns:
        Path of the final ONNX file (the ``_simp`` or ``_simp_slim`` model).

    Raises:
        RuntimeError: if onnxslim fails.
    """
    export_path = _resolve_path(output_path)
    export_dir = os.path.dirname(export_path)
    if export_dir:
        os.makedirs(export_dir, exist_ok=True)

    input_names = list(sample_inputs.keys())
    wrapper.eval()

    dynamic_axes = None
    if args.dynamic_axes:
        all_dynamic_axes = {
            "pixel_values": {0: "batch", 2: "height", 3: "width"},
            "latents": {0: "batch", 2: "latent_h", 3: "latent_w"},
            "images": {0: "batch", 2: "height", 3: "width"},
        }
        # Only pass axes for tensors that actually exist in this graph.
        graph_names = set(input_names) | set(output_names)
        dynamic_axes = {k: v for k, v in all_dynamic_axes.items() if k in graph_names}

    LOGGER.info("Exporting ONNX to %s", export_path)
    with torch.inference_mode():
        torch.onnx.export(
            wrapper,
            args=tuple(sample_inputs[name] for name in input_names),
            f=export_path,
            input_names=input_names,
            output_names=output_names,
            opset_version=args.opset,
            do_constant_folding=True,
            export_params=True,
            dynamic_axes=dynamic_axes,
        )
    LOGGER.info("Raw ONNX export finished")

    onnx_model = onnx.load(export_path)
    simp_onnx_path = os.path.splitext(export_path)[0] + "_simp.onnx"
    # Honour --no-external-data (previously ignored: tensors were always
    # written to an external data file).
    use_external_data = not getattr(args, "no_external_data", False)
    onnx.save(
        onnx_model,
        simp_onnx_path,
        save_as_external_data=use_external_data,
        all_tensors_to_one_file=True,
    )
    if use_external_data:
        LOGGER.info("Saved external-data ONNX to %s", simp_onnx_path)
    else:
        LOGGER.info("Saved single-file ONNX to %s", simp_onnx_path)

    if args.skip_slim:
        LOGGER.info("Skip onnxslim as requested")
        final_path = simp_onnx_path
    else:
        slim_path = os.path.splitext(simp_onnx_path)[0] + "_slim.onnx"
        LOGGER.info("Start onnxslim simplification")
        success = run_onnxslim(simp_onnx_path, slim_path)
        if not success:
            raise RuntimeError("onnxslim simplification failed")
        final_path = slim_path
        LOGGER.info("onnxslim done: %s", final_path)

    return final_path


def run_ort_validation(wrapper: torch.nn.Module, sample_inputs: OrderedDict, onnx_path: str, provider: str) -> None:
    """Compare the first ONNX Runtime output with the PyTorch output and log the error.

    Logs the max absolute difference and that value divided by
    ``max(1, max|torch_output|)``. It only reports; it never fails on a large
    difference. Silently returns if onnxruntime is not installed.
    """
    try:
        import onnxruntime as ort
    except ImportError:
        LOGGER.warning("onnxruntime not installed, skip validation")
        return

    wrapper.eval()
    with torch.inference_mode():
        torch_output = wrapper(*sample_inputs.values()).detach().cpu().numpy()

    session = ort.InferenceSession(onnx_path, providers=[provider])
    ort_inputs = {name: tensor.detach().cpu().numpy() for name, tensor in sample_inputs.items()}
    ort_output = session.run(None, ort_inputs)[0]

    # Compare in float64: subtracting fp16 arrays loses precision and can overflow.
    reference = np.asarray(torch_output, dtype=np.float64)
    candidate = np.asarray(ort_output, dtype=np.float64)
    abs_diff = float(np.max(np.abs(reference - candidate)))
    rel_diff = abs_diff / max(1.0, float(np.max(np.abs(reference))))
    LOGGER.info("ONNX Runtime check done (abs=%.6f, rel=%.6f)", abs_diff, rel_diff)


def main() -> None:
    """Export the encoder, then the decoder (traced on the encoder's output), then validate."""
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch_dtype = torch.float16 if args.dtype == "fp16" else torch.float32
    if torch_dtype == torch.float16 and device.type == "cpu":
        LOGGER.warning("CPU 上不支持 fp16, 自动回退为 fp32")
        torch_dtype = torch.float32

    torch.set_grad_enabled(False)
    vae = load_vae(args, torch_dtype, device)
    sample_inputs = build_dummy_inputs(args, vae, torch_dtype, device)

    encoder_inputs = OrderedDict(pixel_values=sample_inputs["pixel_values"])
    encoder_wrapper = VAEEncoderWrapper(vae)
    maybe_save_calibration_inputs("vae_encoder", encoder_inputs, args)
    with torch.inference_mode():
        # Real encoder output gives the decoder a realistic tracing/calibration input.
        latent_sample = encoder_wrapper(*encoder_inputs.values()).detach()
    encoder_onnx = export_onnx(
        encoder_wrapper,
        encoder_inputs,
        args.encoder_output,
        ["latents"],
        args,
    )

    decoder_inputs = OrderedDict(latents=latent_sample)
    decoder_wrapper = VAEDecoderWrapper(vae)
    maybe_save_calibration_inputs("vae_decoder", decoder_inputs, args)
    decoder_onnx = export_onnx(
        decoder_wrapper,
        decoder_inputs,
        args.decoder_output,
        ["images"],
        args,
    )

    if not args.skip_ort_check:
        # Validation is best-effort: a failure is logged, not raised.
        try:
            run_ort_validation(encoder_wrapper, encoder_inputs, encoder_onnx, args.ort_provider)
            run_ort_validation(decoder_wrapper, decoder_inputs, decoder_onnx, args.ort_provider)
        except Exception as exc:
            LOGGER.warning("ONNX Runtime validation failed: %s", exc)


if __name__ == "__main__":
    main()