yongqiang
initialize this repo
ba96580
#!/usr/bin/env python3
"""导出 VAE Encoder/Decoder 的 ONNX 模型。"""
import argparse
import logging
import os
import sys
from collections import OrderedDict
from typing import Any, Dict, Optional
import numpy as np
import torch
from loguru import logger
import onnx
from onnx import numpy_helper
import subprocess
# Repository root: two directory levels above the directory containing this script.
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))
# Put the repository root on sys.path (only once) before the project import
# below — presumably so that `videox_fun` resolves when this file is run
# directly as a script rather than as part of an installed package.
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)
from videox_fun.models import AutoencoderKL # noqa: E402
# Module-level logger used by the functions in this file; basicConfig sets
# the INFO level and a timestamped message format.
LOGGER = logging.getLogger("export_vae_onnx")
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s")
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the export script.

    Options are registered from a single ordered table so that the order
    shown by ``--help`` matches the order of the table.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args()``.
    """
    option_table = (
        ("--model-root", dict(default="models/Diffusion_Transformer/Z-Image-Turbo/", help="Diffusers 权重所在目录")),
        ("--checkpoint", dict(default=None, help="可选的 VAE finetune checkpoint")),
        ("--encoder-output", dict(default="onnx-models/vae_encoder.onnx", help="VAE Encoder ONNX 路径")),
        ("--decoder-output", dict(default="onnx-models/vae_decoder.onnx", help="VAE Decoder ONNX 路径")),
        ("--height", dict(type=int, default=864, help="导出时的图片高度")),
        ("--width", dict(type=int, default=496, help="导出时的图片宽度")),
        ("--latent-downsample-factor", dict(type=int, default=8, help="VAE 下采样倍数")),
        ("--batch-size", dict(type=int, default=1, help="导出 batch 大小")),
        ("--dtype", dict(choices=["fp16", "fp32"], default="fp16", help="导出精度")),
        ("--opset", dict(type=int, default=17, help="ONNX opset 版本")),
        ("--dynamic-axes", dict(action="store_true", help="是否导出动态维度")),
        ("--skip-ort-check", dict(action="store_true", help="跳过 onnxruntime 结果校验")),
        ("--ort-provider", dict(default="CPUExecutionProvider", help="onnxruntime provider")),
        ("--skip-slim", dict(action="store_true", help="跳过 onnxslim")),
        ("--no-external-data", dict(action="store_true", help="禁用外部数据格式保存")),
        ("--save-calib-inputs", dict(action="store_true", help="保存校准输入 npy")),
        ("--calib-dir", dict(default="onnx-calibration", help="校准输入保存目录")),
    )
    cli = argparse.ArgumentParser(description="Export VAE encoder/decoder to ONNX")
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
def run_onnxslim(input_file: str, output_file: str) -> bool:
    """Run the ``onnxslim`` command-line tool and echo its output.

    The child's stderr is merged into stdout so that a single pipe is
    drained while the process runs. Reading stdout to EOF while stderr sits
    in a separate, unread pipe can block forever once the child fills the
    stderr pipe buffer.

    Args:
        input_file: Path of the ONNX model passed to ``onnxslim`` as input.
        output_file: Path passed to ``onnxslim`` as the output model.

    Returns:
        True when the command exits with status 0. False when it exits with
        a non-zero status, when the executable cannot be found, or when
        launching or reading from the process raises any other exception.
    """
    cmd = ["onnxslim", input_file, output_file]
    try:
        print(f"执行命令: {' '.join(cmd)}")
        captured = []
        with subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,  # single stream: avoids the two-pipe deadlock
            text=True,
            bufsize=1,  # line-buffered so progress appears as it is produced
        ) as process:
            for line in process.stdout:
                print(line, end="")
                captured.append(line)
            returncode = process.wait()
        if returncode != 0:
            # Repeat the tail of the combined output next to the failure message.
            output_tail = "".join(captured[-50:])
            print(f"命令执行失败, 错误信息:\n{output_tail}")
            return False
        print("ONNX模型压缩完成!")
        return True
    except FileNotFoundError:
        # The executable is shipped by the `onnxslim` package.
        print("错误: 未找到 onnxslim 命令, 请确保已安装 onnxslim (pip install onnxslim)")
        return False
    except Exception as exc:
        # Best-effort wrapper: report and signal failure through the return value.
        print(f"执行命令时发生错误: {exc}")
        return False
def _resolve_path(path: str) -> str:
return os.path.abspath(os.path.join(REPO_ROOT, path)) if not os.path.isabs(path) else path
def _check_image_dims(height: int, width: int, factor: int) -> None:
if height % factor != 0 or width % factor != 0:
raise ValueError("height 和 width 需要能被 latent_downsample_factor 整除")
def _compute_latent_dims(height: int, width: int, factor: int) -> Dict[str, int]:
    """Return the latent height and width for an image size.

    The image size is validated with ``_check_image_dims`` first; each
    dimension is then floor-divided by ``factor``.

    Returns:
        A dict with the keys ``"latent_h"`` and ``"latent_w"``.
    """
    _check_image_dims(height, width, factor)
    latent_h = height // factor
    latent_w = width // factor
    return dict(latent_h=latent_h, latent_w=latent_w)
def load_vae(args: argparse.Namespace, torch_dtype: torch.dtype, device: torch.device) -> AutoencoderKL:
    """Load the VAE and optionally overlay a finetuned checkpoint.

    Args:
        args: Parsed options; ``model_root`` and ``checkpoint`` are read.
            Relative paths are resolved through ``_resolve_path``.
        torch_dtype: Dtype the weights are loaded in and moved to.
        device: Device the model is moved to.

    Returns:
        The ``AutoencoderKL`` loaded from the ``vae`` subfolder of the model
        root, moved to ``device``/``torch_dtype`` and switched to eval mode.

    Raises:
        FileNotFoundError: If the model root is not a directory, or a
            checkpoint was requested but does not exist.
    """
    model_root = _resolve_path(args.model_root)
    checkpoint_path = _resolve_path(args.checkpoint) if args.checkpoint else None
    if not os.path.isdir(model_root):
        raise FileNotFoundError(f"Model root not found: {model_root}")
    # Fail fast: check the checkpoint before loading the base weights, so a
    # mistyped path is reported immediately instead of after the model load.
    if checkpoint_path and not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    LOGGER.info("Loading VAE from %s", model_root)
    vae = AutoencoderKL.from_pretrained(
        model_root,
        subfolder="vae",
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    )
    vae.to(device=device, dtype=torch_dtype)
    vae.eval()

    if checkpoint_path:
        LOGGER.info("Loading checkpoint %s", checkpoint_path)
        if checkpoint_path.endswith(".safetensors"):
            from safetensors.torch import load_file  # type: ignore

            state_dict = load_file(checkpoint_path)
        else:
            # SECURITY NOTE: torch.load unpickles the file, which can execute
            # arbitrary code. Only pass checkpoints from a trusted source
            # (or use the .safetensors format handled above).
            state_dict = torch.load(checkpoint_path, map_location="cpu")
        # Some checkpoints nest the weights under a "state_dict" key.
        state_dict = state_dict.get("state_dict", state_dict)
        # strict=False: key mismatches are tolerated and only counted in the log.
        missing, unexpected = vae.load_state_dict(state_dict, strict=False)
        LOGGER.info("Checkpoint loaded (missing=%d, unexpected=%d)", len(missing), len(unexpected))
    return vae
def build_dummy_inputs(args: argparse.Namespace, vae: AutoencoderKL, torch_dtype: torch.dtype, device: torch.device) -> OrderedDict:
    """Create random sample tensors for the image and latent inputs.

    Args:
        args: Parsed options; ``batch_size``, ``height``, ``width`` and
            ``latent_downsample_factor`` are read.
        vae: Model whose ``config.latent_channels`` sets the latent channel count.
        torch_dtype: Dtype of the generated tensors.
        device: Device the tensors are created on.

    Returns:
        An ``OrderedDict`` with ``"pixel_values"`` of shape
        ``(batch, 3, height, width)`` followed by ``"latent"`` of shape
        ``(batch, latent_channels, latent_h, latent_w)``.
    """
    latent_dims = _compute_latent_dims(args.height, args.width, args.latent_downsample_factor)
    tensor_kwargs = dict(dtype=torch_dtype, device=device)

    image_shape = (args.batch_size, 3, args.height, args.width)
    latent_shape = (
        args.batch_size,
        vae.config.latent_channels,
        latent_dims["latent_h"],
        latent_dims["latent_w"],
    )

    dummy = OrderedDict()
    # The image tensor is drawn first, then the latent tensor.
    dummy["pixel_values"] = torch.randn(*image_shape, **tensor_kwargs)
    # NOTE(review): this key is "latent", while main() names the decoder
    # input "latents" — confirm before feeding this entry to an export.
    dummy["latent"] = torch.randn(*latent_shape, **tensor_kwargs)
    return dummy
class VAEEncoderWrapper(torch.nn.Module):
    """Module that exposes the wrapped model's encode path as ``forward``."""

    def __init__(self, model: AutoencoderKL):
        """Store the model whose ``encode`` method is called in ``forward``."""
        super().__init__()
        self.model = model

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode ``pixel_values`` and return ``mode()`` of the first result element."""
        encoded = self.model.encode(pixel_values)
        posterior = encoded[0]
        return posterior.mode()
class VAEDecoderWrapper(torch.nn.Module):
    """Module that exposes the wrapped model's decode path as ``forward``."""

    def __init__(self, model: AutoencoderKL):
        """Store the model whose ``decode`` method is called in ``forward``."""
        super().__init__()
        self.model = model

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        """Decode ``latents`` and return the first element of the result tuple."""
        decoded = self.model.decode(latents, return_dict=False)
        return decoded[0]
def maybe_save_calibration_inputs(tag: str, inputs: OrderedDict, args: argparse.Namespace) -> Optional[str]:
    """Save the given input tensors to a ``.npy`` file when requested.

    Nothing is written unless ``args.save_calib_inputs`` is truthy (a
    missing attribute counts as disabled).

    Args:
        tag: Prefix of the output file name (``<tag>_inputs.npy``).
        inputs: Mapping of input name to tensor.
        args: Parsed options; ``save_calib_inputs`` and ``calib_dir`` are read.

    Returns:
        The path of the written file, or None when saving is disabled.
    """
    enabled = getattr(args, "save_calib_inputs", False)
    if not enabled:
        return None

    target_dir = _resolve_path(args.calib_dir)
    os.makedirs(target_dir, exist_ok=True)

    arrays = {}
    for key, value in inputs.items():
        arrays[key] = value.detach().cpu().numpy()

    destination = os.path.join(target_dir, f"{tag}_inputs.npy")
    # A dict is stored as a pickled object, hence allow_pickle=True.
    np.save(destination, arrays, allow_pickle=True)
    LOGGER.info("Saved calibration inputs (%s) to %s", tag, destination)
    return destination
def dump_initializer_parameters(model_path: str) -> str:
    """Write every graph initializer of an ONNX model into an ``.npz`` file.

    Args:
        model_path: Path of the ONNX model; external data is loaded as well.

    Returns:
        The path of the written archive, ``<model_path>.params.npz``.
    """
    graph = onnx.load(model_path, load_external_data=True).graph
    weights: Dict[str, Any] = {
        init.name: numpy_helper.to_array(init) for init in graph.initializer
    }
    npz_path = f"{model_path}.params.npz"
    # Each initializer is stored under its own name as the archive key.
    np.savez(npz_path, **weights)
    LOGGER.info("Saved %d parameters to %s", len(weights), npz_path)
    return npz_path
def export_onnx(
    wrapper: torch.nn.Module,
    sample_inputs: OrderedDict,
    output_path: str,
    output_names: list,
    args: argparse.Namespace,
) -> str:
    """Export a wrapper module to ONNX, re-save it and optionally slim it.

    Steps: export with ``torch.onnx.export`` to ``output_path``; reload and
    re-save as ``<stem>_simp.onnx``; unless ``args.skip_slim`` is set, run
    ``onnxslim`` on that file to produce ``<stem>_simp_slim.onnx``.

    Args:
        wrapper: Module to export; it is put into eval mode.
        sample_inputs: Ordered mapping of ONNX input name to sample tensor.
            The keys become the graph input names, the values the trace inputs.
        output_path: Destination of the raw export (resolved through
            ``_resolve_path``); its directory is created if needed.
        output_names: Names assigned to the graph outputs.
        args: Parsed options; ``dynamic_axes``, ``opset``, ``skip_slim`` and
            ``no_external_data`` are read.

    Returns:
        Path of the final model: the slimmed file, or the ``_simp`` file when
        slimming is skipped.

    Raises:
        RuntimeError: If ``onnxslim`` reports failure.
    """
    export_path = _resolve_path(output_path)
    export_dir = os.path.dirname(export_path)
    if export_dir:
        os.makedirs(export_dir, exist_ok=True)

    input_names = list(sample_inputs.keys())
    wrapper.eval()

    dynamic_axes = None
    if args.dynamic_axes:
        known_axes = {
            "pixel_values": {0: "batch", 2: "height", 3: "width"},
            "latents": {0: "batch", 2: "latent_h", 3: "latent_w"},
            "images": {0: "batch", 2: "height", 3: "width"},
        }
        # Keep only the entries that name an input or output of this graph.
        graph_tensor_names = set(input_names) | set(output_names)
        dynamic_axes = {
            name: axes for name, axes in known_axes.items() if name in graph_tensor_names
        }

    LOGGER.info("Exporting ONNX to %s", export_path)
    with torch.inference_mode():
        torch.onnx.export(
            wrapper,
            args=tuple(sample_inputs[name] for name in input_names),
            f=export_path,
            input_names=input_names,
            output_names=output_names,
            opset_version=args.opset,
            do_constant_folding=True,
            export_params=True,
            dynamic_axes=dynamic_axes,
        )
    LOGGER.info("Raw ONNX export finished")

    onnx_model = onnx.load(export_path)
    simp_onnx_path = os.path.splitext(export_path)[0] + "_simp.onnx"
    # Honour --no-external-data; getattr keeps namespaces that lack the
    # attribute working with the external-data default.
    use_external_data = not getattr(args, "no_external_data", False)
    # NOTE(review): no `location` is passed, so onnx chooses the name of the
    # external-data file itself — confirm whether repeated runs leave stale
    # data files in the output directory.
    onnx.save(
        onnx_model,
        simp_onnx_path,
        save_as_external_data=use_external_data,
        all_tensors_to_one_file=True,
    )
    if use_external_data:
        LOGGER.info("Saved external-data ONNX to %s", simp_onnx_path)
    else:
        LOGGER.info("Saved single-file ONNX to %s", simp_onnx_path)

    if args.skip_slim:
        LOGGER.info("Skip onnxslim as requested")
        return simp_onnx_path

    slim_path = os.path.splitext(simp_onnx_path)[0] + "_slim.onnx"
    LOGGER.info("Start onnxslim simplification")
    if not run_onnxslim(simp_onnx_path, slim_path):
        raise RuntimeError("onnxslim simplification failed")
    LOGGER.info("onnxslim done: %s", slim_path)
    return slim_path
def run_ort_validation(wrapper: torch.nn.Module, sample_inputs: OrderedDict, onnx_path: str, provider: str) -> None:
    """Compare the PyTorch output with the ONNX Runtime output and log the gap.

    The check is informational: the maximum absolute difference and a
    relative difference (absolute difference divided by
    ``max(1, max|torch_output|)``) are logged, no threshold is enforced.
    If ``onnxruntime`` is not installed the check is skipped with a warning.

    Args:
        wrapper: Module producing the reference output; put into eval mode.
        sample_inputs: Ordered mapping of ONNX input name to tensor; the
            values are passed positionally to ``wrapper`` and, converted to
            numpy, by name to the ONNX Runtime session.
        onnx_path: Path of the ONNX model to load.
        provider: ONNX Runtime execution provider name.

    Raises:
        ValueError: If the two outputs do not have the same shape.
    """
    try:
        import onnxruntime as ort
    except ImportError:
        LOGGER.warning("onnxruntime not installed, skip validation")
        return

    wrapper.eval()
    with torch.inference_mode():
        torch_output = wrapper(*sample_inputs.values()).detach().cpu().numpy()

    session = ort.InferenceSession(onnx_path, providers=[provider])
    ort_inputs = {name: tensor.detach().cpu().numpy() for name, tensor in sample_inputs.items()}
    ort_output = session.run(None, ort_inputs)[0]

    # Without this check, mismatched shapes could broadcast in the
    # subtraction below and yield a meaningless metric.
    if torch_output.shape != ort_output.shape:
        raise ValueError(
            f"Output shape mismatch: torch {torch_output.shape} vs onnxruntime {ort_output.shape}"
        )

    # Compare in float64: with fp16 outputs the subtraction would otherwise
    # be carried out in half precision and quantise the reported error.
    reference = np.asarray(torch_output, dtype=np.float64)
    candidate = np.asarray(ort_output, dtype=np.float64)
    abs_diff = float(np.max(np.abs(reference - candidate)))
    rel_diff = abs_diff / max(1.0, float(np.max(np.abs(reference))))
    LOGGER.info("ONNX Runtime check done (abs=%.6f, rel=%.6f)", abs_diff, rel_diff)
def main() -> None:
    """Export the VAE encoder and decoder to ONNX and optionally validate them.

    Flow: parse options, choose device and dtype (fp16 falls back to fp32 on
    CPU), load the VAE, export the encoder from a random image, export the
    decoder from the encoder's output on that image, then run the ONNX
    Runtime comparison unless ``--skip-ort-check`` was given.
    """
    args = parse_args()

    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    wants_half = args.dtype == "fp16"
    torch_dtype = torch.float16 if wants_half else torch.float32
    if device.type == "cpu" and torch_dtype == torch.float16:
        LOGGER.warning("CPU 上不支持 fp16, 自动回退为 fp32")
        torch_dtype = torch.float32

    # Gradients are disabled for the whole process from here on.
    torch.set_grad_enabled(False)

    vae = load_vae(args, torch_dtype, device)
    dummy = build_dummy_inputs(args, vae, torch_dtype, device)

    # Encoder: only the image entry of the dummy inputs is used.
    enc_inputs = OrderedDict(pixel_values=dummy["pixel_values"])
    enc_module = VAEEncoderWrapper(vae)
    maybe_save_calibration_inputs("vae_encoder", enc_inputs, args)
    with torch.inference_mode():
        # The encoder output becomes the decoder's sample input below.
        latent_sample = enc_module(*enc_inputs.values()).detach()
    encoder_onnx = export_onnx(enc_module, enc_inputs, args.encoder_output, ["latents"], args)

    # Decoder: sampled with the latents the encoder just produced.
    dec_inputs = OrderedDict(latents=latent_sample)
    dec_module = VAEDecoderWrapper(vae)
    maybe_save_calibration_inputs("vae_decoder", dec_inputs, args)
    decoder_onnx = export_onnx(dec_module, dec_inputs, args.decoder_output, ["images"], args)

    if args.skip_ort_check:
        return

    validation_jobs = (
        (enc_module, enc_inputs, encoder_onnx),
        (dec_module, dec_inputs, decoder_onnx),
    )
    try:
        for module, feed, model_path in validation_jobs:
            run_ort_validation(module, feed, model_path, args.ort_provider)
    except Exception as exc:
        # Validation is best-effort: a failure is reported, the exports stay in place.
        LOGGER.warning("ONNX Runtime validation failed: %s", exc)
if __name__ == "__main__":
    # Example:
    #   python scripts/z_image_fun/export_vae_onnx.py
    #       --model-root models/Diffusion_Transformer/Z-Image-Turbo/
    #       --height 512 --width 512
    #       --encoder-output onnx-models/vae_encoder.onnx
    #       --decoder-output onnx-models/vae_decoder.onnx
    #       --dtype fp32
    #       --save-calib-inputs
    #       --skip-ort-check
    # (join the lines above into a single command, or add shell line continuations)
    main()