"""ONNX Runtime inference pipeline (transformer + VAE decoder)."""
from __future__ import annotations

import argparse
import json
import os
import random
import sys
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import onnxruntime as ort
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.torch_utils import randn_tensor
from loguru import logger
from tqdm import tqdm

# Make the script's ancestors and the repository root importable so that
# `videox_fun` and `scripts` resolve regardless of the working directory.
current_file_path = os.path.abspath(__file__)
project_roots = [
    os.path.dirname(current_file_path),
    os.path.dirname(os.path.dirname(current_file_path)),
    os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))),
]
for project_root in project_roots:
    if project_root not in sys.path:
        sys.path.insert(0, project_root)

SCRIPT_DIR = Path(__file__).resolve().parent
REPO_ROOT = SCRIPT_DIR.parents[2]
if REPO_ROOT.as_posix() not in sys.path:
    sys.path.insert(0, REPO_ROOT.as_posix())

from videox_fun.models import AutoTokenizer, Qwen3ForCausalLM
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from scripts.split_quant_onnx_by_subconfigs import (
    SubGraphSpec,
    sanitize,
)

MODEL_NAME = "models/Diffusion_Transformer/Z-Image-Turbo/"
CONFIG_PATH = REPO_ROOT / "VideoX-Fun" / "config" / "z_image" / "z_image.yaml"
TRANSFORMER_CONFIG_PATH = REPO_ROOT / "VideoX-Fun" / "pulsar2_configs" / "transformers_subgraph.json"
TRANSFORMER_ONNX_DIR = REPO_ROOT / "VideoX-Fun" / "compiled_subgraph_from_onnx" / "frontend"
VAE_DECODER_ONNX = REPO_ROOT / "VideoX-Fun" / "compiled_output_vae_decoder" / "frontend" / "optimized.onnx"
SAVE_DIR = REPO_ROOT / "VideoX-Fun" / "samples" / "z-image-t2i-onnx"

DEFAULT_PROMPTS = [
    "(masterpiece, best quality) solo female on a tropical beach, golden hour rim light, cinematic grading",
    "nighttime cyberpunk boulevard, neon reflections on wet asphalt, volumetric fog, wide shot",
    "sunrise over alpine mountains, low clouds in valleys, god rays, ultra-detailed landscape",
    "modern minimal living room, soft natural light, Scandinavian design, high-resolution interior render",
    "classical oil painting of a renaissance noblewoman, chiaroscuro lighting, rich textures",
    "macro photography of a dewdrop on a leaf, extreme detail, shallow depth of field",
    "futuristic sports car parked under neon lights, glossy paint, cinematic 35mm look",
    "ancient library with towering bookshelves, warm candlelight, dust motes in air",
    "portrait of an astronaut in full suit, visor reflection showing earth, studio lighting",
    "stormy sea with a lone lighthouse, crashing waves, dramatic clouds, long exposure feel",
    "cybernetic samurai standing in rain, backlit silhouette, moody blue-orange palette",
    "lush rainforest waterfall, soft mist, saturated greens, wide-angle composition",
    "product shot of a smartwatch on marble, softbox lighting, crisp shadows, advertisement style",
    "architectural exterior of a glass skyscraper at dusk, warm interior lights, reflections",
    "vintage film photograph of a 1950s diner at night, grain and halation, neon signage",
    "hyperrealistic bowl of ramen, steam rising, glossy broth, detailed toppings",
    "fantasy castle on a floating island, waterfalls falling into clouds, sunset lighting",
    "high-fashion editorial portrait, dramatic chiaroscuro, sharp focus on eyes",
    "aerial view of winding river through autumn forest, golden and crimson leaves",
    "studio shot of running shoes mid-air, motion blur trails, vibrant background gradient",
    "noir city alley in the 1940s, hard shadows, rain-slick pavement, moody atmosphere",
    "desert caravan at twilight, silhouettes of camels, soft purple sky, cinematic scope",
    "close-up of a mechanical watch movement, intricate gears, metallic reflections",
    "bioluminescent underwater reef, glowing corals, schools of fish, deep blue tones",
    "portrait of an elderly man with weathered face, soft window light, fine skin detail",
    "snowy village at night, warm cabin lights, smoke from chimneys, peaceful mood",
    "futuristic data center aisle, cool cyan lighting, depth and symmetry",
    "oil painting of a bowl of fruit in Dutch masters style, rich textures, dramatic lighting",
    "sunlit meadow with wildflowers, shallow depth of field, pastel color palette",
    "sci-fi corridor with volumetric light shafts, pristine white surfaces, wide lens",
    "luxury wristwatch on black velvet, high contrast, advertisement macro shot",
    "medieval marketplace at dawn, merchants setting up, soft warm light, lively details",
    "((masterpiece,best quality))1 young beautiful girl,ultra detailed,official art,unity 8k wallpaper,masterpiece, best quality, official art, extremely detailed CG unity 8k wallpaper, highly detailed, 1 girl, aqua eyes, light smile, ((grey hair)), hair flower, bracelet, choker, ribbon, JK, look at viewer, on the beach, in summer,",
]
idx = random.randint(0, len(DEFAULT_PROMPTS) - 1)
PROMPT = DEFAULT_PROMPTS[idx]
NEG_PROMPT = " "
GUIDANCE_SCALE = 0.0
SEED = 42
HEIGHT, WIDTH = 512, 512
NUM_INFERENCE_STEPS = 9
NUM_CHANNELS_LATENTS = 16
VAE_SCALE_FACTOR = 8
PATCH_SIZE = 2
FPATCH_SIZE = 1
MAX_SEQ_LEN = 128
VAE_SCALING_FACTOR = 0.3611
VAE_SHIFT_FACTOR = 0.1159

SAMPLER_MAP = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
SAMPLER_NAME = "Flow"


def _infer_module_device(module: torch.nn.Module) -> torch.device:
    """Best-effort device lookup: first parameter, then first buffer, else CPU."""
    param = next(module.parameters(), None)
    if param is not None:
        return param.device
    buffer = next(module.buffers(), None)
    if buffer is not None:
        return buffer.device
    return torch.device("cpu")


@contextmanager
def module_to_device(module: torch.nn.Module, target_device: torch.device):
    """Temporarily move `module` to `target_device`, restoring it afterwards.

    If the module was moved onto a CUDA device, the CUDA cache is emptied on
    exit so the module does not pin GPU memory between uses.
    """
    if module is None:
        yield module
        return
    original_device = _infer_module_device(module)
    target_device = target_device or original_device
    needs_move = original_device != target_device
    moved_to_cuda = needs_move and target_device.type == "cuda"
    if needs_move:
        module.to(target_device)
    try:
        yield module
    finally:
        if needs_move:
            module.to(original_device)
        if moved_to_cuda and torch.cuda.is_available():
            # `index or current_device()` would misfire for device index 0,
            # so test against None explicitly.
            cache_device = (
                target_device.index
                if target_device.index is not None
                else torch.cuda.current_device()
            )
            with torch.cuda.device(cache_device):
                torch.cuda.empty_cache()
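
# Usage sketch mirroring _encode_prompt below: keep the text encoder on CPU
# between calls and hop onto the GPU only for the forward pass, e.g.
#   with module_to_device(text_encoder, torch.device("cuda:0")):
#       out = text_encoder(input_ids=..., attention_mask=..., output_hidden_states=True)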


def _encode_prompt(
    tokenizer: AutoTokenizer,
    text_encoder: Qwen3ForCausalLM,
    prompt: Union[str, List[str]],
    device: torch.device,
    prompt_embeds: Optional[List[torch.FloatTensor]] = None,
    max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
    """Encode prompts with the Qwen3 text encoder via its chat template.

    Returns the second-to-last hidden state, one tensor per batch entry.
    Precomputed `prompt_embeds` are passed through unchanged.
    """
    if prompt_embeds is not None:
        return prompt_embeds
    prompts = [prompt] if isinstance(prompt, str) else list(prompt)
    for i, item in enumerate(prompts):
        messages = [{"role": "user", "content": item}]
        prompts[i] = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True, enable_thinking=True
        )
    text_inputs = tokenizer(
        prompts,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids.to(device)
    prompt_masks = text_inputs.attention_mask.to(device).bool()
    with module_to_device(text_encoder, device):
        prompt_embeds = text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_masks,
            output_hidden_states=True,
        ).hidden_states[-2]
    return [prompt_embeds[i] for i in range(len(prompt_embeds))]


def encode_prompt(
    tokenizer: AutoTokenizer,
    text_encoder: Qwen3ForCausalLM,
    prompt: Union[str, List[str]],
    device: torch.device,
    do_classifier_free_guidance: bool,
    negative_prompt: Optional[Union[str, List[str]]],
    max_sequence_length: int,
) -> Tuple[List[torch.FloatTensor], List[torch.FloatTensor]]:
    """Encode the positive prompt, plus the negative prompt when CFG is on."""
    prompt_embeds = _encode_prompt(
        tokenizer, text_encoder, prompt, device, None, max_sequence_length
    )
    negative_embeds: List[torch.FloatTensor] = []
    if do_classifier_free_guidance:
        neg = negative_prompt or ""
        negative_list = [neg] if isinstance(neg, str) else list(neg)
        negative_embeds = _encode_prompt(
            tokenizer, text_encoder, negative_list, device, None, max_sequence_length
        )
    return prompt_embeds, negative_embeds


def _stack_prompt_embeddings(prompt_embeds_input):
    """Stack a list of per-prompt embeddings into a single batch tensor."""
    if isinstance(prompt_embeds_input, list):
        return torch.stack(prompt_embeds_input, dim=0)
    return prompt_embeds_input


def prepare_latents(
    batch_size: int,
    num_channels_latents: int,
    height: int,
    width: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: torch.Generator,
) -> torch.FloatTensor:
    """Draw the initial Gaussian latents for the requested resolution."""
    # Round down to the latent grid (pixels / VAE_SCALE_FACTOR), kept even for
    # the transformer's 2x2 patching.
    height = 2 * (int(height) // (VAE_SCALE_FACTOR * 2))
    width = 2 * (int(width) // (VAE_SCALE_FACTOR * 2))
    shape = (batch_size, num_channels_latents, height, width)
    latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    return latents
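
# Shape check with the 512x512 defaults: 2 * (512 // 16) == 64 on each side,
# so the initial latents come out as (1, 16, 64, 64).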


def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    """Linearly interpolate the flow-matching shift `mu` from the token count."""
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b
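
# Worked example with the defaults above: a 512x512 image gives a 64x64 latent,
# image_seq_len = (64 // 2) * (64 // 2) = 1024, so
# m = 0.65 / 3840 ≈ 1.693e-4, b = 0.5 - 256 * m ≈ 0.4567, and mu ≈ 0.63.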


def retrieve_timesteps(
    scheduler,
    num_inference_steps: int,
    device: torch.device,
    **kwargs,
):
    """Set the scheduler timesteps (forwarding e.g. `mu`) and return them."""
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    return scheduler.timesteps


def resolve_providers(use_cuda: bool) -> List[str]:
    """Prefer CUDA when requested and available, always keeping a CPU fallback."""
    available = ort.get_available_providers()
    if use_cuda and "CUDAExecutionProvider" in available:
        return ["CUDAExecutionProvider", "CPUExecutionProvider"]
    return ["CPUExecutionProvider"]


class OnnxSplitTransformer:
    """Run a transformer that was split into multiple ONNX subgraphs.

    The split layout is described by a pulsar2-style JSON config; extra
    `auto_*.onnx` files found in the model directory are loaded as well.
    """

    def __init__(self, config_path: Path, model_dir: Path, providers: List[str]):
        self.config_path = config_path
        self.model_dir = model_dir
        self.providers = providers
        self.specs = self._load_specs()
        self.sessions, self.default_final_output = self._load_sessions_and_outputs()

    def _load_specs(self) -> List[SubGraphSpec]:
        """Parse `compiler.sub_configs` from the split config into specs."""
        with self.config_path.open("r", encoding="utf-8") as f:
            config = json.load(f)
        sub_configs = config.get("compiler", {}).get("sub_configs", [])
        if not sub_configs:
            raise ValueError("config file is missing compiler.sub_configs")

        specs: List[SubGraphSpec] = []
        for i, entry in enumerate(sub_configs):
            start = [name for name in entry.get("start_tensor_names", []) if name]
            end = [name for name in entry.get("end_tensor_names", []) if name]
            if not start or not end:
                raise ValueError(f"sub_config[{i}] is incomplete")
            specs.append(
                SubGraphSpec(
                    label=f"cfg_{i:02d}",
                    start=start,
                    end=end,
                    node_names=set(),
                    source="config",
                )
            )
        return specs

    def _expected_path(self, spec: SubGraphSpec) -> Path:
        """Map a spec to its expected ONNX filename inside `model_dir`."""
        head = sanitize(spec.start[0]) if spec.start else "const"
        tail = sanitize(spec.end[0]) if spec.end else "out"
        filename = f"{spec.label}_{head}_to_{tail}_{spec.source}.onnx"
        path = self.model_dir / filename
        if not path.exists():
            raise FileNotFoundError(f"missing ONNX model: {path}")
        return path
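
    # Naming example (tensor names here are illustrative, and this assumes
    # sanitize() only makes a name filesystem-safe): a config spec with start
    # "latent_model_input" and end "hidden_out" resolves to
    # "cfg_00_latent_model_input_to_hidden_out_config.onnx".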

    def _load_sessions_and_outputs(self) -> Tuple[List[Tuple[SubGraphSpec, ort.InferenceSession]], str]:
        """Create one InferenceSession per subgraph and pick the default output."""
        sessions: List[Tuple[SubGraphSpec, ort.InferenceSession]] = []
        for spec in self.specs:
            path = self._expected_path(spec)
            sess = ort.InferenceSession(path.as_posix(), providers=self.providers)
            sessions.append((spec, sess))

        # Pick up any auto-split subgraphs exported next to the configured ones.
        auto_specs: List[SubGraphSpec] = []
        for path in sorted(self.model_dir.glob("auto_*.onnx")):
            sess = ort.InferenceSession(path.as_posix(), providers=self.providers)
            inputs = [i.name for i in sess.get_inputs()]
            outputs = [o.name for o in sess.get_outputs()]
            auto_spec = SubGraphSpec(
                label=path.stem,
                start=inputs,
                end=outputs,
                node_names=set(),
                source="auto",
            )
            auto_specs.append(auto_spec)
            sessions.append((auto_spec, sess))

        # Default final output: the last auto subgraph when one exists,
        # otherwise the last configured spec.
        if auto_specs:
            default_output = auto_specs[-1].end[0]
        else:
            default_output = self.specs[-1].end[0]

        self.specs = self.specs + auto_specs
        return sessions, default_output

    def __call__(
        self,
        latent_np: np.ndarray,
        prompt_np: np.ndarray,
        timestep_np: np.ndarray,
        final_output_name: Optional[str] = None,
    ) -> np.ndarray:
        """Run the split transformer by dataflow scheduling.

        Subgraphs execute in whatever order their inputs become available,
        so the split order in the config does not have to match the
        topological order of the original graph.
        """
        tensor_store: Dict[str, np.ndarray] = {
            "latent_model_input": latent_np,
            "prompt_embeds": prompt_np,
            "timestep": timestep_np,
        }
        executed = set()
        target = final_output_name or self.default_final_output

        while target not in tensor_store:
            progressed = False
            for spec, session in self.sessions:
                if spec.label in executed:
                    continue
                if not all(name in tensor_store for name in spec.start):
                    continue
                inputs = {name: tensor_store[name] for name in spec.start}
                outputs = session.run(spec.end, inputs)
                for out_name, value in zip(spec.end, outputs):
                    tensor_store[out_name] = value
                executed.add(spec.label)
                progressed = True
            if not progressed:
                missing = [
                    (spec.label, [name for name in spec.start if name not in tensor_store])
                    for spec, _ in self.sessions
                    if spec.label not in executed
                ]
                raise RuntimeError(
                    f"Subgraph scheduling stalled; missing inputs: {missing}; "
                    f"currently available: {list(tensor_store.keys())}"
                )

        return tensor_store[target]
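
# Usage sketch (paths are illustrative; parse_args() holds the real defaults):
#   runner = OnnxSplitTransformer(
#       Path("pulsar2_configs/transformers_subgraph.json"),
#       Path("compiled_subgraph_from_onnx/frontend"),
#       ["CPUExecutionProvider"],
#   )
#   velocity = runner(latent_np, prompt_np, timestep_np)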


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="ONNX Runtime inference (transformer + VAE)")
    parser.add_argument("--prompt", type=str, default=None, help="positive prompt; a random built-in sample is used when omitted")
    parser.add_argument("--negative-prompt", type=str, default=NEG_PROMPT, help="negative prompt")
    parser.add_argument("--steps", type=int, default=NUM_INFERENCE_STEPS, help="number of denoising steps")
    parser.add_argument("--height", type=int, default=HEIGHT, help="output height, must be divisible by 16")
    parser.add_argument("--width", type=int, default=WIDTH, help="output width, must be divisible by 16")
    parser.add_argument("--seed", type=int, default=SEED, help="random seed")
    parser.add_argument("--sampler", type=str, choices=list(SAMPLER_MAP.keys()), default=SAMPLER_NAME, help="sampler")
    parser.add_argument("--max-seq-len", type=int, default=MAX_SEQ_LEN, help="maximum text sequence length")
    parser.add_argument("--save-dir", type=str, default=str(SAVE_DIR), help="output directory")
    parser.add_argument("--transformer-config", type=str, default=str(TRANSFORMER_CONFIG_PATH), help="subgraph config JSON")
    parser.add_argument("--transformer-subgraph-dir", type=str, default=str(TRANSFORMER_ONNX_DIR), help="directory of subgraph ONNX models")
    parser.add_argument("--vae-onnx", type=str, default=str(VAE_DECODER_ONNX), help="path to the VAE decoder ONNX model")
    parser.add_argument("--use-cuda-provider", action="store_true", help="prefer CUDAExecutionProvider when available")
    parser.add_argument("--save-decoder-input", action="store_true", help="save the decoder input as a .npy file")
    parser.add_argument("--final-output-name", type=str, default=None, help="name of the final output tensor; defaults to the first output of the last subgraph")
    parser.add_argument("--no-progress", action="store_true", help="disable the progress bar")
    return parser.parse_args()


def main() -> None:
    args = parse_args()
    prompt_text = args.prompt if args.prompt is not None else PROMPT
    logger.info(f"Using prompt: {prompt_text}")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.set_grad_enabled(False)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder="tokenizer")
    text_encoder = Qwen3ForCausalLM.from_pretrained(
        MODEL_NAME,
        subfolder="text_encoder",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    text_encoder.eval()

    scheduler_cls = SAMPLER_MAP[args.sampler]
    scheduler = scheduler_cls.from_pretrained(MODEL_NAME, subfolder="scheduler")

    image_processor = VaeImageProcessor(vae_scale_factor=VAE_SCALE_FACTOR * 2)

    prompt_embeds, _ = encode_prompt(
        tokenizer,
        text_encoder,
        prompt_text,
        device,
        do_classifier_free_guidance=False,
        negative_prompt=args.negative_prompt,
        max_sequence_length=args.max_seq_len,
    )

    latents = prepare_latents(
        batch_size=1,
        num_channels_latents=NUM_CHANNELS_LATENTS,
        height=args.height,
        width=args.width,
        dtype=torch.float32,
        device=device,
        generator=torch.Generator(device=device).manual_seed(args.seed),
    )

    # Resolution-dependent timestep shift for the flow-matching scheduler.
    image_seq_len = (latents.shape[2] // PATCH_SIZE) * (latents.shape[3] // PATCH_SIZE)
    mu = calculate_shift(
        image_seq_len,
        scheduler.config.get("base_image_seq_len", 256),
        scheduler.config.get("max_image_seq_len", 4096),
        scheduler.config.get("base_shift", 0.5),
        scheduler.config.get("max_shift", 1.15),
    )
    timesteps = retrieve_timesteps(scheduler, args.steps, device=device, mu=mu)

    providers = resolve_providers(args.use_cuda_provider)
    transformer_runner = OnnxSplitTransformer(
        Path(args.transformer_config), Path(args.transformer_subgraph_dir), providers
    )
    final_output_name = args.final_output_name or transformer_runner.default_final_output

    prompt_embeds_tensor = _stack_prompt_embeddings(prompt_embeds)

    iterator = timesteps if args.no_progress else tqdm(timesteps, desc="Denoising", dynamic_ncols=True)

    for t in iterator:
        timestep = t.expand(latents.shape[0])
        # Map the scheduler's [1000, 0] timesteps onto the [0, 1] flow time
        # the exported transformer expects.
        timestep_model_input = (1000 - timestep) / 1000

        # The exported graph takes a 5D latent (an extra frame axis of size 1).
        latent_model_input = latents.to(torch.float32)
        latent_np = latent_model_input.unsqueeze(2).cpu().numpy()
        prompt_np = prompt_embeds_tensor.to(dtype=torch.float32).cpu().numpy()
        timestep_np = timestep_model_input.to(dtype=torch.float32).cpu().numpy()

        model_out = transformer_runner(latent_np, prompt_np, timestep_np, final_output_name)
        if model_out.ndim == 5 and model_out.shape[2] == 1:
            model_out = np.squeeze(model_out, axis=2)
        model_out_tensor = torch.from_numpy(model_out).to(device=device, dtype=torch.float32)
        # The transformer's output uses the opposite sign convention to the
        # scheduler's expected prediction, hence the negation.
        noise_pred = -model_out_tensor

        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    # Undo the VAE latent normalization: z = z / scaling_factor + shift_factor.
    latents = (latents / VAE_SCALING_FACTOR) + VAE_SHIFT_FACTOR
    decoder_input = latents.to(dtype=torch.float32).cpu().numpy()

    if args.save_decoder_input:
        save_dir_path = Path(args.save_dir)
        save_dir_path.mkdir(parents=True, exist_ok=True)
        np.save(save_dir_path / "decoder_input.npy", decoder_input)
        logger.info("Saved decoder input as .npy")

    if decoder_input.ndim == 5 and decoder_input.shape[2] == 1:
        decoder_input = np.squeeze(decoder_input, axis=2)

    vae_session = ort.InferenceSession(Path(args.vae_onnx).as_posix(), providers=providers)
    image = vae_session.run(None, {"latents": decoder_input})[0]

    image = torch.from_numpy(image).to(device=device, dtype=torch.float32)
    image = image_processor.postprocess(image, output_type="pil")

    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    target_path = save_dir / f"z_image_onnx_{idx}.png"
    image[0].save(target_path)
    logger.info(f"ONNX Runtime inference finished; result saved to {target_path}")


if __name__ == "__main__":
    """
    # Example command for 512x512 generation:
    python examples/z_image_fun/launcher_onnx.py \
        --transformer-config pulsar2_configs/transformers_subgraph.json \
        --transformer-subgraph-dir transformers_body_only_split_onnx --vae-onnx onnx-models/vae_decoder_simp_slim.onnx

    # Example command for 1728x992 generation:
    python examples/z_image_fun/launcher_onnx.py \
        --transformer-config pulsar2_configs/transformers_subgraph_1728x992.json \
        --transformer-subgraph-dir transformers_body_only_1728_992_split_onnx \
        --vae-onnx onnx-models-1728x992/vae_decoder_simp_slim.onnx \
        --max-seq-len 256 \
        --height 1728 --width 992
    """
    main()