# yongqiang
# initialize this repo
# ba96580
#!/usr/bin/env python3
"""使用 onnxruntime 推理链路 (transformer + VAE decoder)。"""
from __future__ import annotations
import argparse
import json
import os
import random
import sys
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union
import numpy as np
import onnxruntime as ort
import torch
from tqdm import tqdm
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.torch_utils import randn_tensor
from loguru import logger
from omegaconf import OmegaConf
from PIL import Image
# Make sibling packages importable regardless of the working directory:
# register this file's directory and its two nearest ancestors on sys.path.
current_file_path = os.path.abspath(__file__)
project_roots = [
    os.path.dirname(current_file_path),
    os.path.dirname(os.path.dirname(current_file_path)),
    os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))),
]
for project_root in project_roots:
    # Plain `if` statement instead of the original side-effecting
    # conditional expression (`insert(...) if ... else None`).
    if project_root not in sys.path:
        sys.path.insert(0, project_root)

SCRIPT_DIR = Path(__file__).resolve().parent
# The repository root sits three levels above this file. Guard against
# shallow checkouts where parents[2] would raise IndexError.
_script_parents = SCRIPT_DIR.parents
REPO_ROOT = _script_parents[min(2, len(_script_parents) - 1)]
if REPO_ROOT.as_posix() not in sys.path:
    sys.path.insert(0, REPO_ROOT.as_posix())
from videox_fun.models import AutoTokenizer, Qwen3ForCausalLM
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from videox_fun.utils.utils import get_image_latent
from scripts.split_quant_onnx_by_subconfigs import (
SubGraphSpec,
sanitize,
)
# -----------------------------------------------------------------------------
# Model and resource paths
# -----------------------------------------------------------------------------
# Diffusion-transformer checkpoint directory (tokenizer / text_encoder / scheduler subfolders).
MODEL_NAME = "models/Diffusion_Transformer/Z-Image-Turbo/"
# Pipeline YAML config — not referenced anywhere below in this script.
CONFIG_PATH = REPO_ROOT / "VideoX-Fun" / "config" / "z_image" / "z_image.yaml"
# JSON describing how the transformer ONNX graph was split into subgraphs.
TRANSFORMER_CONFIG_PATH = REPO_ROOT / "VideoX-Fun" / "pulsar2_configs" / "transformers_subgraph.json"
# Directory containing the split transformer subgraph .onnx files.
TRANSFORMER_ONNX_DIR = REPO_ROOT / "VideoX-Fun" / "compiled_subgraph_from_onnx" / "frontend"
# Single-file ONNX VAE decoder.
VAE_DECODER_ONNX = REPO_ROOT / "VideoX-Fun" / "compiled_output_vae_decoder" / "frontend" / "optimized.onnx"
# Default output directory for generated images.
SAVE_DIR = REPO_ROOT / "VideoX-Fun" / "samples" / "z-image-t2i-onnx"
# -----------------------------------------------------------------------------
# Runtime configuration
# -----------------------------------------------------------------------------
# Prompt pool sampled from when --prompt is not supplied on the command line.
DEFAULT_PROMPTS = [
    "(masterpiece, best quality) solo female on a tropical beach, golden hour rim light, cinematic grading",
    "nighttime cyberpunk boulevard, neon reflections on wet asphalt, volumetric fog, wide shot",
    "sunrise over alpine mountains, low clouds in valleys, god rays, ultra-detailed landscape",
    "modern minimal living room, soft natural light, Scandinavian design, high-resolution interior render",
    "classical oil painting of a renaissance noblewoman, chiaroscuro lighting, rich textures",
    "macro photography of a dewdrop on a leaf, extreme detail, shallow depth of field",
    "futuristic sports car parked under neon lights, glossy paint, cinematic 35mm look",
    "ancient library with towering bookshelves, warm candlelight, dust motes in air",
    "portrait of an astronaut in full suit, visor reflection showing earth, studio lighting",
    "stormy sea with a lone lighthouse, crashing waves, dramatic clouds, long exposure feel",
    "cybernetic samurai standing in rain, backlit silhouette, moody blue-orange palette",
    "lush rainforest waterfall, soft mist, saturated greens, wide-angle composition",
    "product shot of a smartwatch on marble, softbox lighting, crisp shadows, advertisement style",
    "architectural exterior of a glass skyscraper at dusk, warm interior lights, reflections",
    "vintage film photograph of a 1950s diner at night, grain and halation, neon signage",
    "hyperrealistic bowl of ramen, steam rising, glossy broth, detailed toppings",
    "fantasy castle on a floating island, waterfalls falling into clouds, sunset lighting",
    "high-fashion editorial portrait, dramatic chiaroscuro, sharp focus on eyes",
    "aerial view of winding river through autumn forest, golden and crimson leaves",
    "studio shot of running shoes mid-air, motion blur trails, vibrant background gradient",
    "noir city alley in the 1940s, hard shadows, rain-slick pavement, moody atmosphere",
    "desert caravan at twilight, silhouettes of camels, soft purple sky, cinematic scope",
    "close-up of a mechanical watch movement, intricate gears, metallic reflections",
    "bioluminescent underwater reef, glowing corals, schools of fish, deep blue tones",
    "portrait of an elderly man with weathered face, soft window light, fine skin detail",
    "snowy village at night, warm cabin lights, smoke from chimneys, peaceful mood",
    "futuristic data center aisle, cool cyan lighting, depth and symmetry",
    "oil painting of a bowl of fruit in Dutch masters style, rich textures, dramatic lighting",
    "sunlit meadow with wildflowers, shallow depth of field, pastel color palette",
    "sci-fi corridor with volumetric light shafts, pristine white surfaces, wide lens",
    "luxury wristwatch on black velvet, high contrast, advertisement macro shot",
    "medieval marketplace at dawn, merchants setting up, soft warm light, lively details",
    "((masterpiece,best quality))1 young beautiful girl,ultra detailed,official art,unity 8k wallpaper,masterpiece, best quality, official art, extremely detailed CG unity 8k wallpaper, highly detailed, 1 girl, aqua eyes, light smile, ((grey hair)), hair flower, bracelet, choker, ribbon, JK, look at viewer, on the beach, in summer,",
]
# Random default prompt; `idx` is also reused by main() for the output filename.
idx = random.randint(0, len(DEFAULT_PROMPTS) - 1)
PROMPT = DEFAULT_PROMPTS[idx]
NEG_PROMPT = " "  # default negative prompt (single space)
GUIDANCE_SCALE = 0.0  # not referenced below; CFG is disabled in main()
SEED = 42
HEIGHT, WIDTH = 512, 512
NUM_INFERENCE_STEPS = 9
NUM_CHANNELS_LATENTS = 16  # latent channels generated for the transformer
VAE_SCALE_FACTOR = 8  # VAE spatial downscale; latent dims derive from this
PATCH_SIZE = 2  # not referenced in this script — presumably exporter parity
FPATCH_SIZE = 1  # not referenced in this script — presumably exporter parity
MAX_SEQ_LEN = 128  # default max text-token length
# Latent (de)normalization constants applied before VAE decoding;
# must match the VAE the decoder ONNX was exported from.
VAE_SCALING_FACTOR = 0.3611
VAE_SHIFT_FACTOR = 0.1159
# Flow-matching samplers selectable via --sampler.
SAMPLER_MAP = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
SAMPLER_NAME = "Flow"
# -----------------------------------------------------------------------------
# Utility functions
# -----------------------------------------------------------------------------
def _infer_module_device(module: torch.nn.Module) -> torch.device:
param = next(module.parameters(), None)
if param is not None:
return param.device
buffer = next(module.buffers(), None)
if buffer is not None:
return buffer.device
return torch.device("cpu")
@contextmanager
def module_to_device(module: torch.nn.Module, target_device: torch.device):
    """Temporarily move *module* to *target_device*, restoring it on exit.

    Yields the module (or ``None`` if ``module`` is ``None``). When the module
    was moved onto a CUDA device, the CUDA cache of that device is emptied
    after moving it back, to release the transient allocation.
    """
    if module is None:
        yield module
        return
    original_device = _infer_module_device(module)
    target_device = target_device or original_device
    needs_move = original_device != target_device
    moved_to_cuda = needs_move and target_device.type == "cuda"
    if needs_move:
        module.to(target_device)
    try:
        yield module
    finally:
        if needs_move:
            module.to(original_device)
        if moved_to_cuda and torch.cuda.is_available():
            # BUG FIX: `target_device.index or current_device()` treated
            # device index 0 as falsy and cleared the wrong device's cache.
            cache_device = (
                target_device.index
                if target_device.index is not None
                else torch.cuda.current_device()
            )
            with torch.cuda.device(cache_device):
                torch.cuda.empty_cache()
def _encode_prompt(
    tokenizer: AutoTokenizer,
    text_encoder: Qwen3ForCausalLM,
    prompt: Union[str, List[str]],
    device: torch.device,
    prompt_embeds: Optional[List[torch.FloatTensor]] = None,
    max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
    """Embed prompts with the Qwen3 text encoder.

    Returns *prompt_embeds* unchanged when already provided. Otherwise each
    prompt is wrapped in a user chat message, tokenized to a fixed length,
    and the second-to-last hidden state is taken as its embedding.
    """
    if prompt_embeds is not None:
        return prompt_embeds

    raw_prompts = [prompt] if isinstance(prompt, str) else list(prompt)
    templated = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
        for text in raw_prompts
    ]
    encoded = tokenizer(
        templated,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device).bool()
    with module_to_device(text_encoder, device):
        hidden = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        ).hidden_states[-2]
    # Split the batch dimension into a list of per-sample tensors.
    return list(hidden)
def encode_prompt(
    tokenizer: AutoTokenizer,
    text_encoder: Qwen3ForCausalLM,
    prompt: Union[str, List[str]],
    device: torch.device,
    do_classifier_free_guidance: bool,
    negative_prompt: Optional[Union[str, List[str]]],
    max_sequence_length: int,
) -> Tuple[List[torch.FloatTensor], List[torch.FloatTensor]]:
    """Encode the positive prompt, plus the negative prompt when CFG is enabled.

    Returns a (positive, negative) pair of embedding lists; the negative list
    is empty when classifier-free guidance is disabled.
    """
    positive = _encode_prompt(
        tokenizer, text_encoder, prompt, device, None, max_sequence_length
    )
    if not do_classifier_free_guidance:
        return positive, []
    fallback = negative_prompt or ""
    if isinstance(fallback, str):
        negative_inputs = [fallback]
    else:
        negative_inputs = list(fallback)
    negatives = _encode_prompt(
        tokenizer, text_encoder, negative_inputs, device, None, max_sequence_length
    )
    return positive, negatives
def _stack_prompt_embeddings(prompt_embeds_input):
if isinstance(prompt_embeds_input, list):
return torch.stack(prompt_embeds_input, dim=0)
return prompt_embeds_input
def prepare_latents(
    batch_size: int,
    num_channels_latents: int,
    height: int,
    width: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: torch.Generator,
) -> torch.FloatTensor:
    """Draw initial Gaussian latents.

    Pixel dimensions are mapped to latent dimensions by dividing by
    ``VAE_SCALE_FACTOR`` and rounding down to an even size.
    """
    downscale = VAE_SCALE_FACTOR * 2
    latent_height = (int(height) // downscale) * 2
    latent_width = (int(width) // downscale) * 2
    return randn_tensor(
        (batch_size, num_channels_latents, latent_height, latent_width),
        generator=generator,
        device=device,
        dtype=dtype,
    )
def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    """Linearly interpolate the flow-matching shift for a given image sequence length.

    Yields ``base_shift`` at ``base_seq_len`` and ``max_shift`` at
    ``max_seq_len``, extrapolating linearly outside that range.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return slope * image_seq_len + intercept
def retrieve_timesteps(
    scheduler,
    num_inference_steps: int,
    device: torch.device,
    **kwargs,
):
    """Configure *scheduler* for ``num_inference_steps`` and return its timestep schedule."""
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    timesteps = scheduler.timesteps
    return timesteps
def resolve_providers(use_cuda: bool) -> List[str]:
    """Choose ONNXRuntime execution providers, preferring CUDA when requested and available."""
    available_providers = set(ort.get_available_providers())
    providers = ["CPUExecutionProvider"]
    if use_cuda and "CUDAExecutionProvider" in available_providers:
        providers.insert(0, "CUDAExecutionProvider")
    return providers
# -----------------------------------------------------------------------------
# ONNX transformer subgraph executor
# -----------------------------------------------------------------------------
class OnnxSplitTransformer:
    """Executor for a transformer exported as multiple ONNX subgraphs.

    Subgraph specs come from two sources:
      * the ``compiler.sub_configs`` entries of a JSON config, each giving
        explicit start/end tensor names, and
      * any extra ``auto_*.onnx`` files in ``model_dir``, whose inputs and
        outputs are read directly from the ONNX graphs.

    Calling the instance schedules the subgraphs dynamically: a subgraph runs
    as soon as all of its input tensors are present in the tensor store.
    """

    def __init__(self, config_path: Path, model_dir: Path, providers: List[str]):
        self.config_path = config_path
        self.model_dir = model_dir
        self.providers = providers
        self.specs = self._load_specs()
        self.sessions, self.default_final_output = self._load_sessions_and_outputs()

    def _load_specs(self) -> List[SubGraphSpec]:
        """Parse ``compiler.sub_configs`` from the JSON config into SubGraphSpec objects.

        Raises ValueError when the section is missing or an entry lacks
        start/end tensor names.
        """
        with self.config_path.open("r", encoding="utf-8") as f:
            config = json.load(f)
        sub_configs = config.get("compiler", {}).get("sub_configs", [])
        if not sub_configs:
            raise ValueError("配置文件缺少 compiler.sub_configs")
        specs: List[SubGraphSpec] = []
        for idx, entry in enumerate(sub_configs):
            # Drop empty tensor names; both boundary lists must be non-empty.
            start = [name for name in entry.get("start_tensor_names", []) if name]
            end = [name for name in entry.get("end_tensor_names", []) if name]
            if not start or not end:
                raise ValueError(f"sub_config[{idx}] 定义不完整")
            specs.append(
                SubGraphSpec(
                    label=f"cfg_{idx:02d}",
                    start=start,
                    end=end,
                    node_names=set(),
                    source="config",
                )
            )
        return specs

    def _expected_path(self, spec: SubGraphSpec) -> Path:
        """Reconstruct the ONNX filename the splitter wrote for *spec*; raise if absent."""
        head = sanitize(spec.start[0]) if spec.start else "const"
        tail = sanitize(spec.end[0]) if spec.end else "out"
        filename = f"{spec.label}_{head}_to_{tail}_{spec.source}.onnx"
        path = self.model_dir / filename
        if not path.exists():
            raise FileNotFoundError(f"缺少 ONNX 模型: {path}")
        return path

    def _load_sessions_and_outputs(self) -> Tuple[List[Tuple[SubGraphSpec, ort.InferenceSession]], str]:
        """Create InferenceSessions for every subgraph and pick the default final output tensor."""
        sessions: List[Tuple[SubGraphSpec, ort.InferenceSession]] = []
        # First load the subgraphs declared in the config.
        for spec in self.specs:
            path = self._expected_path(spec)
            sess = ort.InferenceSession(path.as_posix(), providers=self.providers)
            sessions.append((spec, sess))
        # Then try to load additional auto_* subgraphs; their inputs/outputs
        # are read straight from the ONNX graphs.
        auto_specs: List[SubGraphSpec] = []
        for path in sorted(self.model_dir.glob("auto_*.onnx")):
            sess = ort.InferenceSession(path.as_posix(), providers=self.providers)
            inputs = [i.name for i in sess.get_inputs()]
            outputs = [o.name for o in sess.get_outputs()]
            auto_spec = SubGraphSpec(
                label=path.stem,
                start=inputs,
                end=outputs,
                node_names=set(),
                source="auto",
            )
            auto_specs.append(auto_spec)
            sessions.append((auto_spec, sess))
        # Default output: prefer the first output of the last auto subgraph,
        # otherwise the first output of the last config subgraph.
        if auto_specs:
            default_output = auto_specs[-1].end[0]
        else:
            default_output = self.specs[-1].end[0]
        # Also fold auto_specs into self.specs so scheduling errors can report them.
        self.specs = self.specs + auto_specs
        return sessions, default_output

    def __call__(
        self,
        latent_np: np.ndarray,
        prompt_np: np.ndarray,
        timestep_np: np.ndarray,
        final_output_name: Optional[str] = None,
    ) -> np.ndarray:
        """Run subgraphs until *final_output_name* (default: detected final output) is produced.

        The three arguments seed the tensor store under the fixed names
        ``latent_model_input``, ``prompt_embeds`` and ``timestep``, which the
        exported subgraphs are expected to consume.

        Raises:
            RuntimeError: when no remaining subgraph has all of its inputs
                available, i.e. the dependency chain cannot make progress.
        """
        tensor_store: Dict[str, np.ndarray] = {
            "latent_model_input": latent_np,
            "prompt_embeds": prompt_np,
            "timestep": timestep_np,
        }
        executed = set()
        target = final_output_name or self.default_final_output
        # Dynamic scheduling: run a subgraph only once all of its inputs are ready.
        while target not in tensor_store:
            progressed = False
            for spec, session in self.sessions:
                if spec.label in executed:
                    continue
                if not all(name in tensor_store for name in spec.start):
                    continue
                inputs = {name: tensor_store[name] for name in spec.start}
                outputs = session.run(spec.end, inputs)
                for out_name, value in zip(spec.end, outputs):
                    tensor_store[out_name] = value
                executed.add(spec.label)
                progressed = True
            if not progressed:
                # Report, per pending subgraph, which inputs are still missing.
                missing = [
                    (spec.label, [name for name in spec.start if name not in tensor_store])
                    for spec, _ in self.sessions
                    if spec.label not in executed
                ]
                raise RuntimeError(
                    f"子图调度中断,缺少输入: {missing}; 当前可用: {list(tensor_store.keys())}"
                )
        return tensor_store[target]
# -----------------------------------------------------------------------------
# Main pipeline
# -----------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line options for the ONNX inference script."""
    cli = argparse.ArgumentParser(description="ONNXRuntime 推理 (transformer + VAE)")
    cli.add_argument("--prompt", type=str, default=None, help="正向提示词,不填则使用预置随机样本")
    cli.add_argument("--negative-prompt", type=str, default=NEG_PROMPT, help="反向提示词")
    cli.add_argument("--steps", type=int, default=NUM_INFERENCE_STEPS, help="迭代步数")
    cli.add_argument("--height", type=int, default=HEIGHT, help="生成高度,需被 16 整除")
    cli.add_argument("--width", type=int, default=WIDTH, help="生成宽度,需被 16 整除")
    cli.add_argument("--seed", type=int, default=SEED, help="随机种子")
    cli.add_argument("--sampler", type=str, choices=list(SAMPLER_MAP.keys()), default=SAMPLER_NAME, help="采样器")
    cli.add_argument("--max-seq-len", type=int, default=MAX_SEQ_LEN, help="最大文本长度")
    cli.add_argument("--save-dir", type=str, default=str(SAVE_DIR), help="结果输出目录")
    cli.add_argument("--transformer-config", type=str, default=str(TRANSFORMER_CONFIG_PATH), help="子图配置 json")
    cli.add_argument("--transformer-subgraph-dir", type=str, default=str(TRANSFORMER_ONNX_DIR), help="子图 onnx 目录")
    cli.add_argument("--vae-onnx", type=str, default=str(VAE_DECODER_ONNX), help="VAE decoder onnx 路径")
    cli.add_argument("--use-cuda-provider", action="store_true", help="优先使用 CUDAExecutionProvider")
    cli.add_argument("--save-decoder-input", action="store_true", help="是否保存 decoder 输入 npy")
    cli.add_argument("--final-output-name", type=str, default=None, help="指定最终输出 tensor 名称,默认为最后一个子图的第一个输出")
    cli.add_argument("--no-progress", action="store_true", help="关闭进度条输出")
    return cli.parse_args()
def main() -> None:
    """Full text-to-image pipeline: PyTorch text encoding, ONNX denoising, ONNX VAE decode."""
    args = parse_args()
    prompt_text = args.prompt if args.prompt is not None else PROMPT
    logger.info(f"使用的 prompt: {prompt_text}")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.set_grad_enabled(False)  # inference only
    # Text encoding still runs in PyTorch (bfloat16); only the transformer
    # and the VAE decoder run through ONNXRuntime.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder="tokenizer")
    text_encoder = Qwen3ForCausalLM.from_pretrained(
        MODEL_NAME,
        subfolder="text_encoder",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    text_encoder.eval()
    scheduler_cls = SAMPLER_MAP[args.sampler]
    scheduler = scheduler_cls.from_pretrained(MODEL_NAME, subfolder="scheduler")
    image_processor = VaeImageProcessor(vae_scale_factor=VAE_SCALE_FACTOR * 2)
    # CFG is disabled, so the negative embeddings are discarded here.
    prompt_embeds, _ = encode_prompt(
        tokenizer,
        text_encoder,
        prompt_text,
        device,
        do_classifier_free_guidance=False,
        negative_prompt=args.negative_prompt,
        max_sequence_length=args.max_seq_len,
    )
    latents = prepare_latents(
        batch_size=1,
        num_channels_latents=NUM_CHANNELS_LATENTS,
        height=args.height,
        width=args.width,
        dtype=torch.float32,
        device=device,
        generator=torch.Generator(device=device).manual_seed(args.seed),
    )
    # Resolution-dependent shift (mu) for the flow-matching timestep schedule.
    image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
    mu = calculate_shift(
        image_seq_len,
        scheduler.config.get("base_image_seq_len", 256),
        scheduler.config.get("max_image_seq_len", 4096),
        scheduler.config.get("base_shift", 0.5),
        scheduler.config.get("max_shift", 1.15),
    )
    timesteps = retrieve_timesteps(scheduler, args.steps, device=device, mu=mu)
    providers = resolve_providers(args.use_cuda_provider)
    transformer_runner = OnnxSplitTransformer(
        Path(args.transformer_config), Path(args.transformer_subgraph_dir), providers
    )
    final_output_name = args.final_output_name or transformer_runner.default_final_output
    prompt_embeds_tensor = _stack_prompt_embeddings(prompt_embeds)
    iterator = timesteps if args.no_progress else tqdm(timesteps, desc="Denoising", dynamic_ncols=True)
    for t in iterator:
        timestep = t.expand(latents.shape[0])
        # The exported transformer takes a reversed, normalized timestep in [0, 1].
        timestep_model_input = (1000 - timestep) / 1000
        latent_model_input = latents.to(torch.float32)
        # ONNX graph expects a 5D latent: singleton axis inserted at dim 2
        # (presumably the temporal axis of a video-oriented export — TODO confirm).
        latent_np = latent_model_input.unsqueeze(2).to(dtype=torch.float32).cpu().numpy()
        prompt_np = prompt_embeds_tensor.to(dtype=torch.float32).cpu().numpy()
        timestep_np = timestep_model_input.to(dtype=torch.float32).cpu().numpy()
        model_out = transformer_runner(latent_np, prompt_np, timestep_np, final_output_name)
        # Drop the singleton axis again (checked both on the numpy array and the tensor).
        if model_out.ndim == 5 and model_out.shape[2] == 1:
            model_out = np.squeeze(model_out, axis=2)
        model_out_tensor = torch.from_numpy(model_out).to(device=device, dtype=torch.float32)
        if model_out_tensor.dim() == 5 and model_out_tensor.size(2) == 1:
            model_out_tensor = model_out_tensor.squeeze(2)
        # Negate the transformer output before the scheduler step
        # (sign convention of the exported model).
        noise_pred = -model_out_tensor
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    # Undo latent normalization before VAE decoding.
    latents = (latents / VAE_SCALING_FACTOR) + VAE_SHIFT_FACTOR
    decoder_input = latents.to(dtype=torch.float32).cpu().numpy()
    if args.save_decoder_input:
        SAVE_DIR_PATH = Path(args.save_dir)
        SAVE_DIR_PATH.mkdir(parents=True, exist_ok=True)
        np.save(SAVE_DIR_PATH / "decoder_input.npy", decoder_input)
        logger.info("已保存 decoder 输入为 npy")
    if decoder_input.ndim == 5 and decoder_input.shape[2] == 1:
        decoder_input = np.squeeze(decoder_input, axis=2)
    vae_session = ort.InferenceSession(Path(args.vae_onnx).as_posix(), providers=providers)
    image = vae_session.run(None, {"latents": decoder_input})[0]
    image = torch.from_numpy(image).to(device=device, dtype=torch.float32)
    image = image_processor.postprocess(image, output_type="pil")
    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    # `idx` is the module-level random prompt index, reused to label the output file.
    target_path = save_dir / f"z_image_onnx_{idx}.png"
    image[0].save(target_path)
    logger.info(f"ONNXRuntime 推理完成,结果保存到 {target_path}")
if __name__ == "__main__":
    """
    # Example command for 512x512 generation:
    python examples/z_image_fun/launcher_onnx.py \
        --transformer-config pulsar2_configs/transformers_subgraph.json \
        --transformer-subgraph-dir transformers_body_only_split_onnx --vae-onnx onnx-models/vae_decoder_simp_slim.onnx
    # Example command for 1728x992 generation:
    python examples/z_image_fun/launcher_onnx.py \
        --transformer-config pulsar2_configs/transformers_subgraph_1728x992.json \
        --transformer-subgraph-dir transformers_body_only_1728_992_split_onnx \
        --vae-onnx onnx-models-1728x992/vae_decoder_simp_slim.onnx \
        --max-seq-len 256 \
        --height 1728 --width 992
    """
    main()