|
|
|
|
|
"""导出 VAE Encoder/Decoder 的 ONNX 模型。""" |
|
|
|
|
|
import argparse |
|
|
import logging |
|
|
import os |
|
|
import sys |
|
|
from collections import OrderedDict |
|
|
from typing import Any, Dict, Optional |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
from loguru import logger |
|
|
import onnx |
|
|
from onnx import numpy_helper |
|
|
import subprocess |
|
|
|
|
|
# Repository root: two directory levels above this script's location.
# Prepended to sys.path so `videox_fun` resolves when this file is run
# directly as a script (rather than installed as a package).
REPO_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))

if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)
|
|
|
|
|
from videox_fun.models import AutoencoderKL |
|
|
|
|
|
# Module-level logger; basicConfig is called here so log lines appear when
# the script is run standalone (no host application configures logging).
LOGGER = logging.getLogger("export_vae_onnx")

logging.basicConfig(level=logging.INFO, format="[%(asctime)s] %(levelname)s: %(message)s")
|
|
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build the CLI parser for the VAE ONNX export and parse sys.argv."""
    parser = argparse.ArgumentParser(description="Export VAE encoder/decoder to ONNX")
    add = parser.add_argument  # shorthand; option order defines --help order
    add("--model-root", default="models/Diffusion_Transformer/Z-Image-Turbo/", help="Diffusers 权重所在目录")
    add("--checkpoint", default=None, help="可选的 VAE finetune checkpoint")
    add("--encoder-output", default="onnx-models/vae_encoder.onnx", help="VAE Encoder ONNX 路径")
    add("--decoder-output", default="onnx-models/vae_decoder.onnx", help="VAE Decoder ONNX 路径")
    add("--height", type=int, default=864, help="导出时的图片高度")
    add("--width", type=int, default=496, help="导出时的图片宽度")
    add("--latent-downsample-factor", type=int, default=8, help="VAE 下采样倍数")
    add("--batch-size", type=int, default=1, help="导出 batch 大小")
    add("--dtype", choices=["fp16", "fp32"], default="fp16", help="导出精度")
    add("--opset", type=int, default=17, help="ONNX opset 版本")
    add("--dynamic-axes", action="store_true", help="是否导出动态维度")
    add("--skip-ort-check", action="store_true", help="跳过 onnxruntime 结果校验")
    add("--ort-provider", default="CPUExecutionProvider", help="onnxruntime provider")
    add("--skip-slim", action="store_true", help="跳过 onnxslim")
    add("--no-external-data", action="store_true", help="禁用外部数据格式保存")
    add("--save-calib-inputs", action="store_true", help="保存校准输入 npy")
    add("--calib-dir", default="onnx-calibration", help="校准输入保存目录")
    return parser.parse_args()
|
|
|
|
|
|
|
|
def run_onnxslim(input_file: str, output_file: str) -> bool:
    """Run `onnxslim` as a subprocess, streaming its output to stdout.

    Args:
        input_file: Path of the ONNX model to simplify.
        output_file: Path where the simplified model is written.

    Returns:
        True on success, False on any failure (missing binary, non-zero exit,
        unexpected error).
    """
    try:
        cmd = ["onnxslim", input_file, output_file]
        print(f"执行命令: {' '.join(cmd)}")
        # Merge stderr into stdout: the previous code read stdout line by line
        # while stderr was a separate PIPE, which can deadlock once the stderr
        # pipe buffer fills. A single merged stream avoids that.
        process = subprocess.Popen(
            cmd,
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            text=True,  # implies universal newlines; the old duplicate flag is dropped
            bufsize=1,
        )
        for line in process.stdout:
            print(line, end="")
        process.wait()
        if process.returncode != 0:
            print(f"命令执行失败, 返回码: {process.returncode}")
            return False
        print("ONNX模型压缩完成!")
        return True
    except FileNotFoundError:
        # Fixed install hint: the binary comes from the `onnxslim` package,
        # not `onnx-simplifier`.
        print("错误: 未找到 onnxslim 命令, 请确保已安装 onnxslim (pip install onnxslim)")
        return False
    except Exception as exc:
        print(f"执行命令时发生错误: {exc}")
        return False
|
|
|
|
|
|
|
|
def _resolve_path(path: str) -> str:
    """Return *path* unchanged if absolute, else resolve it against REPO_ROOT."""
    if os.path.isabs(path):
        return path
    return os.path.abspath(os.path.join(REPO_ROOT, path))
|
|
|
|
|
|
|
|
def _check_image_dims(height: int, width: int, factor: int) -> None: |
|
|
if height % factor != 0 or width % factor != 0: |
|
|
raise ValueError("height 和 width 需要能被 latent_downsample_factor 整除") |
|
|
|
|
|
|
|
|
def _compute_latent_dims(height: int, width: int, factor: int) -> Dict[str, int]: |
|
|
_check_image_dims(height, width, factor) |
|
|
return {"latent_h": height // factor, "latent_w": width // factor} |
|
|
|
|
|
|
|
|
def load_vae(args: argparse.Namespace, torch_dtype: torch.dtype, device: torch.device) -> AutoencoderKL:
    """Load the pretrained VAE in eval mode, optionally applying a finetune checkpoint.

    Args:
        args: Parsed CLI options (uses `model_root` and `checkpoint`).
        torch_dtype: Precision the model is loaded and cast to.
        device: Target device for the model.

    Raises:
        FileNotFoundError: when the model root or checkpoint path is missing.
    """
    model_root = _resolve_path(args.model_root)
    if not os.path.isdir(model_root):
        raise FileNotFoundError(f"Model root not found: {model_root}")

    LOGGER.info("Loading VAE from %s", model_root)
    vae = AutoencoderKL.from_pretrained(
        model_root,
        subfolder="vae",
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=True,
    )
    vae.to(device=device, dtype=torch_dtype)
    vae.eval()

    if not args.checkpoint:
        return vae

    checkpoint_path = _resolve_path(args.checkpoint)
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")
    LOGGER.info("Loading checkpoint %s", checkpoint_path)
    if checkpoint_path.endswith(".safetensors"):
        from safetensors.torch import load_file

        state_dict = load_file(checkpoint_path)
    else:
        state_dict = torch.load(checkpoint_path, map_location="cpu")
        # Training checkpoints may wrap the weights under a "state_dict" key.
        state_dict = state_dict.get("state_dict", state_dict)
    # strict=False: the finetune checkpoint may cover only a subset of weights.
    missing, unexpected = vae.load_state_dict(state_dict, strict=False)
    LOGGER.info("Checkpoint loaded (missing=%d, unexpected=%d)", len(missing), len(unexpected))
    return vae
|
|
|
|
|
|
|
|
def build_dummy_inputs(args: argparse.Namespace, vae: AutoencoderKL, torch_dtype: torch.dtype, device: torch.device) -> OrderedDict:
    """Create random tensors matching the export resolution.

    Returns an OrderedDict with:
      - ``pixel_values``: (batch, 3, height, width) image tensor
      - ``latent``: (batch, latent_channels, latent_h, latent_w) latent tensor
    """
    dims = _compute_latent_dims(args.height, args.width, args.latent_downsample_factor)
    tensor_kwargs = {"dtype": torch_dtype, "device": device}
    pixel_shape = (args.batch_size, 3, args.height, args.width)
    latent_shape = (
        args.batch_size,
        vae.config.latent_channels,
        dims["latent_h"],
        dims["latent_w"],
    )
    # Keyword arguments are evaluated left to right, so the RNG draw order
    # (pixels first, then latents) matches the original implementation.
    return OrderedDict(
        pixel_values=torch.randn(*pixel_shape, **tensor_kwargs),
        latent=torch.randn(*latent_shape, **tensor_kwargs),
    )
|
|
|
|
|
|
|
|
class VAEEncoderWrapper(torch.nn.Module):
    """ONNX-export wrapper mapping pixel values to a deterministic latent."""

    def __init__(self, model: AutoencoderKL):
        super().__init__()
        self.model = model

    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
        # encode() returns the posterior distribution as its first element;
        # take its mode (not a sample) so the exported graph is deterministic.
        posterior = self.model.encode(pixel_values)[0]
        return posterior.mode()
|
|
|
|
|
|
|
|
class VAEDecoderWrapper(torch.nn.Module):
    """ONNX-export wrapper mapping latents back to image space."""

    def __init__(self, model: AutoencoderKL):
        super().__init__()
        self.model = model

    def forward(self, latents: torch.Tensor) -> torch.Tensor:
        # return_dict=False yields a tuple; element 0 is the decoded image.
        decoded = self.model.decode(latents, return_dict=False)
        return decoded[0]
|
|
|
|
|
|
|
|
def maybe_save_calibration_inputs(tag: str, inputs: OrderedDict, args: argparse.Namespace) -> Optional[str]:
    """Persist *inputs* as a pickled dict of numpy arrays for calibration.

    No-op (returns None) unless --save-calib-inputs was given; otherwise
    returns the path of the saved `.npy` file.
    """
    if not getattr(args, "save_calib_inputs", False):
        return None
    output_dir = _resolve_path(args.calib_dir)
    os.makedirs(output_dir, exist_ok=True)
    file_path = os.path.join(output_dir, f"{tag}_inputs.npy")
    as_numpy = {}
    for name, tensor in inputs.items():
        as_numpy[name] = tensor.detach().cpu().numpy()
    # allow_pickle is required because a dict (not a plain array) is stored.
    np.save(file_path, as_numpy, allow_pickle=True)
    LOGGER.info("Saved calibration inputs (%s) to %s", tag, file_path)
    return file_path
|
|
|
|
|
|
|
|
def dump_initializer_parameters(model_path: str) -> str:
    """Extract every graph initializer of an ONNX model into a `.npz` archive.

    Returns the path of the written archive (`<model_path>.params.npz`).
    """
    model_proto = onnx.load(model_path, load_external_data=True)
    params = {
        initializer.name: numpy_helper.to_array(initializer)
        for initializer in model_proto.graph.initializer
    }
    param_path = f"{model_path}.params.npz"
    np.savez(param_path, **params)
    LOGGER.info("Saved %d parameters to %s", len(params), param_path)
    return param_path
|
|
|
|
|
|
|
|
def export_onnx(
    wrapper: torch.nn.Module,
    sample_inputs: OrderedDict,
    output_path: str,
    output_names: list,
    args: argparse.Namespace,
) -> str:
    """Export *wrapper* to ONNX and post-process the resulting file.

    Pipeline: torch.onnx.export -> (unless --no-external-data) re-save in the
    external-data format -> (unless --skip-slim) simplify with onnxslim.

    Args:
        wrapper: Module to export (encoder or decoder wrapper).
        sample_inputs: Ordered mapping of input name -> example tensor.
        output_path: Target ONNX path (relative paths resolve against REPO_ROOT).
        output_names: Names for the graph outputs.
        args: Parsed CLI options.

    Returns:
        Path of the final ONNX file.

    Raises:
        RuntimeError: when onnxslim simplification fails.
    """
    export_path = _resolve_path(output_path)
    export_dir = os.path.dirname(export_path)
    if export_dir:
        os.makedirs(export_dir, exist_ok=True)

    input_names = list(sample_inputs.keys())
    wrapper.eval()

    dynamic_axes = None
    if args.dynamic_axes:
        # Axes are declared for all known tensor names and filtered so each
        # export only receives entries for its own inputs/outputs.
        all_axes = {
            "pixel_values": {0: "batch", 2: "height", 3: "width"},
            "latents": {0: "batch", 2: "latent_h", 3: "latent_w"},
            "images": {0: "batch", 2: "height", 3: "width"},
        }
        dynamic_axes = {k: v for k, v in all_axes.items() if k in input_names + output_names}

    LOGGER.info("Exporting ONNX to %s", export_path)
    with torch.inference_mode():
        torch.onnx.export(
            wrapper,
            args=tuple(sample_inputs[name] for name in input_names),
            f=export_path,
            input_names=input_names,
            output_names=output_names,
            opset_version=args.opset,
            do_constant_folding=True,
            export_params=True,
            dynamic_axes=dynamic_axes,
        )

    LOGGER.info("Raw ONNX export finished")

    if args.no_external_data:
        # Fix: --no-external-data was parsed but previously ignored. Honor it
        # by keeping the single-file model emitted by torch.onnx.export.
        simp_onnx_path = export_path
    else:
        onnx_model = onnx.load(export_path)
        simp_onnx_path = os.path.splitext(export_path)[0] + "_simp.onnx"
        onnx.save(
            onnx_model,
            simp_onnx_path,
            save_as_external_data=True,
            all_tensors_to_one_file=True,
        )
        LOGGER.info("Saved external-data ONNX to %s", simp_onnx_path)

    if args.skip_slim:
        LOGGER.info("Skip onnxslim as requested")
        final_path = simp_onnx_path
    else:
        slim_path = os.path.splitext(simp_onnx_path)[0] + "_slim.onnx"
        LOGGER.info("Start onnxslim simplification")
        if not run_onnxslim(simp_onnx_path, slim_path):
            raise RuntimeError("onnxslim simplification failed")
        final_path = slim_path
        LOGGER.info("onnxslim done: %s", final_path)

    return final_path
|
|
|
|
|
|
|
|
def run_ort_validation(wrapper: torch.nn.Module, sample_inputs: OrderedDict, onnx_path: str, provider: str) -> None:
    """Compare the torch wrapper's output with onnxruntime on the same inputs.

    Logs max absolute and relative differences; silently skips when
    onnxruntime is not installed.
    """
    try:
        import onnxruntime as ort
    except ImportError:
        LOGGER.warning("onnxruntime not installed, skip validation")
        return

    wrapper.eval()
    with torch.inference_mode():
        reference = wrapper(*sample_inputs.values()).detach().cpu().numpy()

    feeds = {name: tensor.detach().cpu().numpy() for name, tensor in sample_inputs.items()}
    session = ort.InferenceSession(onnx_path, providers=[provider])
    candidate = session.run(None, feeds)[0]

    abs_diff = np.max(np.abs(reference - candidate))
    # Relative difference is floored at an absolute scale of 1.0 to avoid
    # dividing by a near-zero reference magnitude.
    rel_diff = abs_diff / max(1.0, float(np.max(np.abs(reference))))
    LOGGER.info("ONNX Runtime check done (abs=%.6f, rel=%.6f)", abs_diff, rel_diff)
|
|
|
|
|
|
|
|
def main() -> None:
    """Export the VAE encoder and decoder to ONNX, then optionally validate them."""
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch_dtype = torch.float16 if args.dtype == "fp16" else torch.float32
    if device.type == "cpu" and torch_dtype == torch.float16:
        LOGGER.warning("CPU 上不支持 fp16, 自动回退为 fp32")
        torch_dtype = torch.float32

    torch.set_grad_enabled(False)
    vae = load_vae(args, torch_dtype, device)
    sample_inputs = build_dummy_inputs(args, vae, torch_dtype, device)

    # --- encoder export ---
    encoder_inputs = OrderedDict(pixel_values=sample_inputs["pixel_values"])
    encoder_wrapper = VAEEncoderWrapper(vae)
    maybe_save_calibration_inputs("vae_encoder", encoder_inputs, args)

    # Run the encoder once so the decoder is exported with a realistic latent.
    with torch.inference_mode():
        latent_sample = encoder_wrapper(*encoder_inputs.values()).detach()

    encoder_onnx = export_onnx(
        encoder_wrapper, encoder_inputs, args.encoder_output, ["latents"], args
    )

    # --- decoder export ---
    decoder_inputs = OrderedDict(latents=latent_sample)
    decoder_wrapper = VAEDecoderWrapper(vae)
    maybe_save_calibration_inputs("vae_decoder", decoder_inputs, args)

    decoder_onnx = export_onnx(
        decoder_wrapper, decoder_inputs, args.decoder_output, ["images"], args
    )

    if not args.skip_ort_check:
        try:
            run_ort_validation(encoder_wrapper, encoder_inputs, encoder_onnx, args.ort_provider)
            run_ort_validation(decoder_wrapper, decoder_inputs, decoder_onnx, args.ort_provider)
        except Exception as exc:
            # Best-effort check: a validation failure should not void the export.
            LOGGER.warning("ONNX Runtime validation failed: %s", exc)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Example usage:
    #   python scripts/z_image_fun/export_vae_onnx.py \
    #       --model-root models/Diffusion_Transformer/Z-Image-Turbo/ \
    #       --height 512 --width 512 \
    #       --encoder-output onnx-models/vae_encoder.onnx \
    #       --decoder-output onnx-models/vae_decoder.onnx \
    #       --dtype fp32 \
    #       --save-calib-inputs \
    #       --skip-ort-check
    main()
|
|
|