#!/usr/bin/env python3
"""AXModel inference pipeline for Z-Image (transformer subgraphs + VAE decoder)."""
from __future__ import annotations

import argparse
import json
import os
import random
import sys
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union

# Make project packages importable regardless of the working directory:
# register the script directory and its two ancestors on sys.path (idempotent).
current_file_path = os.path.abspath(__file__)
project_roots = [
    os.path.dirname(current_file_path),
    os.path.dirname(os.path.dirname(current_file_path)),
    os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))),
]
for project_root in project_roots:
    if project_root not in sys.path:
        sys.path.insert(0, project_root)

SCRIPT_DIR = Path(__file__).resolve().parent
REPO_ROOT = SCRIPT_DIR.parents[2]
if REPO_ROOT.as_posix() not in sys.path:
    sys.path.insert(0, REPO_ROOT.as_posix())

import numpy as np
import torch
from axengine import InferenceSession as AxInferenceSession
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.torch_utils import randn_tensor
from omegaconf import OmegaConf
from PIL import Image
from loguru import logger
from tqdm import tqdm

from videox_fun.models import AutoTokenizer, Qwen3ForCausalLM
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from videox_fun.utils.utils import get_image_latent

# -----------------------------------------------------------------------------
# Model and resource paths
# -----------------------------------------------------------------------------
MODEL_NAME = "models/Diffusion_Transformer/Z-Image-Turbo/"
CONFIG_PATH = REPO_ROOT / "VideoX-Fun" / "config" / "z_image" / "z_image.yaml"
TRANSFORMER_CONFIG_PATH = REPO_ROOT / "VideoX-Fun" / "pulsar2_configs" / "transformers_subgraph.json"
TRANSFORMER_ONNX_PATH = REPO_ROOT / "VideoX-Fun" / "compiled_subgraph_from_onnx" / "frontend" / "optimized_quant_axmodel.onnx"
# NOTE(review): "comliled" looks like a typo for "compiled", but the example
# commands at the bottom of this file use the same spelling — presumably it
# matches the actual on-disk directory name, so it is kept as-is.
TRANSFORMER_AXMODEL_DIR = REPO_ROOT / "VideoX-Fun" / "comliled_subgraph_from_all_onnx"  # compiled_slice_quant_onnx
VAE_DECODER_AXMODEL = REPO_ROOT / "VideoX-Fun" / "vae_decoder.axmodel"
SAVE_DIR = REPO_ROOT / "VideoX-Fun" / "samples" / "z-image-t2i-axmodel"

# -----------------------------------------------------------------------------
# Run configuration
# -----------------------------------------------------------------------------
DEFAULT_PROMPTS = [
    "(masterpiece, best quality) solo female on a tropical beach, golden hour rim light, cinematic grading",
    "nighttime cyberpunk boulevard, neon reflections on wet asphalt, volumetric fog, wide shot",
    "sunrise over alpine mountains, low clouds in valleys, god rays, ultra-detailed landscape",
    "modern minimal living room, soft natural light, Scandinavian design, high-resolution interior render",
    "classical oil painting of a renaissance noblewoman, chiaroscuro lighting, rich textures",
    "macro photography of a dewdrop on a leaf, extreme detail, shallow depth of field",
    "futuristic sports car parked under neon lights, glossy paint, cinematic 35mm look",
    "ancient library with towering bookshelves, warm candlelight, dust motes in air",
    "portrait of an astronaut in full suit, visor reflection showing earth, studio lighting",
    "stormy sea with a lone lighthouse, crashing waves, dramatic clouds, long exposure feel",
    "cybernetic samurai standing in rain, backlit silhouette, moody blue-orange palette",
    "lush rainforest waterfall, soft mist, saturated greens, wide-angle composition",
    "product shot of a smartwatch on marble, softbox lighting, crisp shadows, advertisement style",
    "architectural exterior of a glass skyscraper at dusk, warm interior lights, reflections",
    "vintage film photograph of a 1950s diner at night, grain and halation, neon signage",
    "hyperrealistic bowl of ramen, steam rising, glossy broth, detailed toppings",
    "fantasy castle on a floating island, waterfalls falling into clouds, sunset lighting",
    "high-fashion editorial portrait, dramatic chiaroscuro, sharp focus on eyes",
    "aerial view of winding river through autumn forest, golden and crimson leaves",
    "studio shot of running shoes mid-air, motion blur trails, vibrant background gradient",
    "noir city alley in the 1940s, hard shadows, rain-slick pavement, moody atmosphere",
    "desert caravan at twilight, silhouettes of camels, soft purple sky, cinematic scope",
    "close-up of a mechanical watch movement, intricate gears, metallic reflections",
    "bioluminescent underwater reef, glowing corals, schools of fish, deep blue tones",
    "portrait of an elderly man with weathered face, soft window light, fine skin detail",
    "snowy village at night, warm cabin lights, smoke from chimneys, peaceful mood",
    "futuristic data center aisle, cool cyan lighting, depth and symmetry",
    "oil painting of a bowl of fruit in Dutch masters style, rich textures, dramatic lighting",
    "sunlit meadow with wildflowers, shallow depth of field, pastel color palette",
    "sci-fi corridor with volumetric light shafts, pristine white surfaces, wide lens",
    "luxury wristwatch on black velvet, high contrast, advertisement macro shot",
    "medieval marketplace at dawn, merchants setting up, soft warm light, lively details",
    "((masterpiece,best quality))1 young beautiful girl,ultra detailed,official art,unity 8k wallpaper,masterpiece, best quality, official art, extremely detailed CG unity 8k wallpaper, highly detailed, 1 girl, aqua eyes, light smile, ((grey hair)), hair flower, bracelet, choker, ribbon, JK, look at viewer, on the beach, in summer,"
]

# A prompt is drawn once at import time; `prompt_idx` is also used later to
# name the output image so runs are traceable back to the prompt.
prompt_idx = random.randint(0, len(DEFAULT_PROMPTS) - 1)
PROMPT = DEFAULT_PROMPTS[prompt_idx]
NEG_PROMPT = " "
GUIDANCE_SCALE = 0.0
SEED = 42
HEIGHT, WIDTH = 512, 512
NUM_INFERENCE_STEPS = 9
NUM_CHANNELS_LATENTS = 16
VAE_SCALE_FACTOR = 8
PATCH_SIZE = 2
FPATCH_SIZE = 1
MAX_SEQ_LEN = 128
VAE_SCALING_FACTOR = 0.3611
VAE_SHIFT_FACTOR = 0.1159

# Scheduler name -> class; selectable via --sampler on the CLI.
SAMPLER_MAP = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
SAMPLER_NAME = "Flow" # 默认最终输出,如果不存在 auto 子图则回退到最后一个 cfg 输出 DEFAULT_FINAL_OUTPUT = None # ----------------------------------------------------------------------------- # 工具函数 (复制自原 launcher 并微调) # ----------------------------------------------------------------------------- def _infer_module_device(module: torch.nn.Module) -> torch.device: param = next(module.parameters(), None) if param is not None: return param.device buffer = next(module.buffers(), None) if buffer is not None: return buffer.device return torch.device("cpu") @contextmanager def module_to_device(module: torch.nn.Module, target_device: torch.device): if module is None: yield module return original_device = _infer_module_device(module) target_device = target_device or original_device needs_move = original_device != target_device moved_to_cuda = needs_move and target_device.type == "cuda" if needs_move: module.to(target_device) try: yield module finally: if needs_move: module.to(original_device) if moved_to_cuda and torch.cuda.is_available(): cache_device = target_device.index or torch.cuda.current_device() with torch.cuda.device(cache_device): torch.cuda.empty_cache() def _encode_prompt( tokenizer: AutoTokenizer, text_encoder: Qwen3ForCausalLM, prompt: Union[str, List[str]], device: torch.device, prompt_embeds: Optional[List[torch.FloatTensor]] = None, max_sequence_length: int = 512, ) -> List[torch.FloatTensor]: if prompt_embeds is not None: return prompt_embeds prompts = [prompt] if isinstance(prompt, str) else list(prompt) for idx, item in enumerate(prompts): messages = [{"role": "user", "content": item}] prompts[idx] = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, enable_thinking=True ) text_inputs = tokenizer( prompts, padding="max_length", max_length=max_sequence_length, truncation=True, return_tensors="pt", ) text_input_ids = text_inputs.input_ids.to(device) prompt_masks = text_inputs.attention_mask.to(device).bool() with module_to_device(text_encoder, 
device): prompt_embeds = text_encoder( input_ids=text_input_ids, attention_mask=prompt_masks, output_hidden_states=True, ).hidden_states[-2] return [prompt_embeds[i] for i in range(len(prompt_embeds))] def encode_prompt( tokenizer: AutoTokenizer, text_encoder: Qwen3ForCausalLM, prompt: Union[str, List[str]], device: torch.device, do_classifier_free_guidance: bool, negative_prompt: Optional[Union[str, List[str]]], max_sequence_length: int, ) -> Tuple[List[torch.FloatTensor], List[torch.FloatTensor]]: prompt_embeds = _encode_prompt( tokenizer, text_encoder, prompt, device, None, max_sequence_length ) negative_embeds: List[torch.FloatTensor] = [] if do_classifier_free_guidance: neg = negative_prompt or "" negative_list = [neg] if isinstance(neg, str) else list(neg) negative_embeds = _encode_prompt( tokenizer, text_encoder, negative_list, device, None, max_sequence_length ) return prompt_embeds, negative_embeds def _stack_prompt_embeddings(prompt_embeds_input): if isinstance(prompt_embeds_input, list): return torch.stack(prompt_embeds_input, dim=0) return prompt_embeds_input def prepare_latents( batch_size: int, num_channels_latents: int, height: int, width: int, dtype: torch.dtype, device: torch.device, generator: torch.Generator, ) -> torch.FloatTensor: height = 2 * (int(height) // (VAE_SCALE_FACTOR * 2)) width = 2 * (int(width) // (VAE_SCALE_FACTOR * 2)) shape = (batch_size, num_channels_latents, height, width) latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) return latents def calculate_shift( image_seq_len: int, base_seq_len: int = 256, max_seq_len: int = 4096, base_shift: float = 0.5, max_shift: float = 1.15, ) -> float: m = (max_shift - base_shift) / (max_seq_len - base_seq_len) b = base_shift - m * base_seq_len return image_seq_len * m + b def retrieve_timesteps( scheduler, num_inference_steps: int, device: torch.device, **kwargs, ): scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) return scheduler.timesteps # 
# -----------------------------------------------------------------------------
# CLI arguments
# -----------------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
    """Parse command-line options; defaults come from the module constants."""
    parser = argparse.ArgumentParser(description="AXModel 推理 (transformer + VAE)")
    parser.add_argument("--prompt", type=str, default=None, help="正向提示词,不填则使用预置随机样本")
    parser.add_argument("--negative-prompt", type=str, default=NEG_PROMPT, help="反向提示词")
    parser.add_argument("--steps", type=int, default=NUM_INFERENCE_STEPS, help="迭代步数")
    parser.add_argument("--height", type=int, default=HEIGHT, help="生成高度,需被 16 整除")
    parser.add_argument("--width", type=int, default=WIDTH, help="生成宽度,需被 16 整除")
    parser.add_argument("--seed", type=int, default=SEED, help="随机种子")
    parser.add_argument("--sampler", type=str, choices=list(SAMPLER_MAP.keys()), default=SAMPLER_NAME, help="采样器")
    parser.add_argument("--max-seq-len", type=int, default=MAX_SEQ_LEN, help="最大文本长度")
    parser.add_argument("--save-dir", type=str, default=str(SAVE_DIR), help="结果输出目录")
    parser.add_argument("--transformer-config", type=str, required=True, help="子图配置 json")
    parser.add_argument("--transformer-onnx", type=str, default=None, help="原始 transformer onnx(可选,sub_configs 已覆盖可不填)")
    parser.add_argument("--transformer-subgraph-dir", type=str, required=True, help="子图 axmodel 目录")
    parser.add_argument("--vae-axmodel", type=str, required=True, help="VAE decoder axmodel 路径")
    parser.add_argument("--final-output-name", type=str, default=None, help="指定最终输出 tensor 名称,默认自动推断")
    parser.add_argument("--save-decoder-input", action="store_true", help="是否保存 decoder 输入 npy")
    parser.add_argument("--no-progress", action="store_true", help="关闭进度条输出")
    return parser.parse_args()


# -----------------------------------------------------------------------------
# AX transformer subgraph executor
# -----------------------------------------------------------------------------
# Imported after the sys.path bootstrap at the top of the file so the project's
# `scripts` package resolves.
from scripts.split_quant_onnx_by_subconfigs import SubGraphSpec, sanitize


class AxSplitTransformer:
    """Runs the diffusion transformer as a chain of AXModel subgraphs.

    Subgraph specs come from two sources: the compiler's `sub_configs` JSON
    ("config" specs) and any `auto_*.axmodel` files found in `model_dir`
    ("auto" specs). Execution is data-driven: a subgraph runs as soon as all
    of its input tensors are available.
    """

    def __init__(self, config_path: Path, onnx_path: Optional[Path], model_dir: Path):
        self.config_path = config_path
        self.onnx_path = onnx_path  # kept for reference; not used during execution
        self.model_dir = model_dir
        # Lazily-populated cache of spec label -> open inference session.
        self._session_cache: Dict[str, AxInferenceSession] = {}
        config_specs = self._load_specs()
        auto_specs = self._load_auto_specs()
        self.specs = config_specs + auto_specs
        # Auto subgraphs, when present, define the terminal outputs; otherwise
        # fall back to the last config subgraph.
        last_group = auto_specs if auto_specs else config_specs
        self.final_outputs = list(last_group[-1].end)
        self.default_final_output = DEFAULT_FINAL_OUTPUT or self.final_outputs[0]

    def _get_session(self, spec: SubGraphSpec) -> AxInferenceSession:
        """Return the cached session for `spec`, opening its axmodel on first use."""
        if spec.label not in self._session_cache:
            path = self._expected_path(spec)
            self._session_cache[spec.label] = AxInferenceSession(path.as_posix())
            logger.info(f"加载子图 session: {spec.label} from {path.name}")
        return self._session_cache[spec.label]

    def close(self) -> None:
        # Explicitly drop all cached session references so they can be freed.
        for key, sess in list(self._session_cache.items()):
            try:
                del sess
            finally:
                self._session_cache.pop(key, None)

    def _load_specs(self) -> List[SubGraphSpec]:
        """Build "config" specs from compiler.sub_configs in the JSON config.

        Raises ValueError when the section is missing or an entry lacks
        start/end tensor names.
        """
        with self.config_path.open("r", encoding="utf-8") as f:
            config = json.load(f)
        sub_configs = config.get("compiler", {}).get("sub_configs", [])
        if not sub_configs:
            raise ValueError("配置文件缺少 compiler.sub_configs")
        specs: List[SubGraphSpec] = []
        for idx, entry in enumerate(sub_configs):
            # Drop empty/None tensor names defensively.
            start = [name for name in entry.get("start_tensor_names", []) if name]
            end = [name for name in entry.get("end_tensor_names", []) if name]
            if not start or not end:
                raise ValueError(f"sub_config[{idx}] 定义不完整")
            spec = SubGraphSpec(
                label=f"cfg_{idx:02d}",
                start=start,
                end=end,
                node_names=set(),
                source="config",
            )
            specs.append(spec)
        return specs

    def _load_auto_specs(self) -> List[SubGraphSpec]:
        """Discover "auto" specs by probing auto_*.axmodel files in model_dir.

        Files that fail to load or expose no usable I/O are skipped with a
        warning rather than aborting startup.
        """
        specs: List[SubGraphSpec] = []
        for path in sorted(self.model_dir.glob("auto_*.axmodel")):
            try:
                session = AxInferenceSession(path.as_posix())
                inputs = [info.name for info in session.get_inputs() if getattr(info, "name", None)]
                outputs = [info.name for info in session.get_outputs() if getattr(info, "name", None)]
                # Cache the session now to avoid re-opening the file later.
                self._session_cache[path.stem] = session
            except Exception as exc:  # pragma: no cover - defensive
                logger.warning(f"跳过 {path.name},加载/解析 IO 失败: {exc}")
                continue
            if not inputs or not outputs:
                logger.warning(f"跳过 {path.name},未找到有效的输入/输出定义")
                continue
            specs.append(
                SubGraphSpec(
                    label=path.stem,
                    start=inputs,
                    end=outputs,
                    node_names=set(),
                    source="auto",
                    output_path=path,
                )
            )
        return specs

    def _expected_path(self, spec: SubGraphSpec) -> Path:
        """Resolve the axmodel file for `spec`; raises FileNotFoundError if absent."""
        if spec.output_path is not None:
            path = spec.output_path
        else:
            # Config specs use the naming scheme produced by the split script:
            # <label>_<first-start>_to_<first-end>_<source>.axmodel
            head = sanitize(spec.start[0]) if spec.start else "const"
            tail = sanitize(spec.end[0]) if spec.end else "out"
            filename = f"{spec.label}_{head}_to_{tail}_{spec.source}.axmodel"
            path = self.model_dir / filename
        if not path.exists():
            raise FileNotFoundError(f"缺少 AXModel: {path}")
        return path

    def __call__(
        self,
        latent_np: np.ndarray,
        prompt_np: np.ndarray,
        timestep_np: np.ndarray,
        final_output_name: Optional[str] = None,
    ) -> np.ndarray:
        """Execute subgraphs until `final_output_name` is produced.

        Raises RuntimeError if scheduling deadlocks (no runnable subgraph can
        make progress toward the target tensor).
        """
        # Seed the tensor pool with the three graph-level inputs.
        tensor_store: Dict[str, np.ndarray] = {
            "latent_model_input": latent_np,
            "prompt_embeds": prompt_np,
            "timestep": timestep_np,
        }
        executed = set()
        target = final_output_name or self.default_final_output
        # Readiness-driven execution: run each subgraph once all its inputs
        # are present in the tensor pool.
        while target not in tensor_store:
            progressed = False
            for spec in self.specs:
                if spec.label in executed:
                    continue
                if not all(name in tensor_store for name in spec.start):
                    continue
                session = self._get_session(spec)
                inputs = {name: tensor_store[name] for name in spec.start}
                outputs = session.run(spec.end, inputs)
                for out_name, value in zip(spec.end, outputs):
                    tensor_store[out_name] = value
                executed.add(spec.label)
                progressed = True
            if not progressed:
                # Deadlock: report which subgraphs are blocked and on what.
                missing = [
                    (spec.label, [name for name in spec.start if name not in tensor_store])
                    for spec in self.specs
                    if spec.label not in executed
                ]
                raise RuntimeError(
                    f"子图调度中断,缺少输入: {missing}; 当前可用: {list(tensor_store.keys())}"
                )
        return tensor_store[target]
# -----------------------------------------------------------------------------
# Main pipeline
# -----------------------------------------------------------------------------
def main() -> None:
    """Text-to-image generation: encode prompt, denoise on AXModel subgraphs,
    decode latents with the AXModel VAE decoder, and save a PNG."""
    args = parse_args()
    prompt_text = args.prompt if args.prompt is not None else PROMPT
    logger.info(f"使用的 prompt: {prompt_text}")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.set_grad_enabled(False)
    # Text encoder runs in PyTorch (bf16); only transformer + VAE run on AXModel.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder="tokenizer")
    text_encoder = Qwen3ForCausalLM.from_pretrained(
        MODEL_NAME,
        subfolder="text_encoder",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    text_encoder.eval()
    scheduler_cls = SAMPLER_MAP[args.sampler]
    scheduler = scheduler_cls.from_pretrained(MODEL_NAME, subfolder="scheduler")
    # One latent pixel spans VAE_SCALE_FACTOR * 2 = 16 image pixels (VAE x patchify).
    image_processor = VaeImageProcessor(vae_scale_factor=VAE_SCALE_FACTOR * 2)
    # CFG is disabled (guidance scale 0 for the Turbo model), so negative
    # embeddings are discarded.
    prompt_embeds, _ = encode_prompt(
        tokenizer,
        text_encoder,
        prompt_text,
        device,
        do_classifier_free_guidance=False,
        negative_prompt=args.negative_prompt,
        max_sequence_length=args.max_seq_len,
    )
    latents = prepare_latents(
        batch_size=1,
        num_channels_latents=NUM_CHANNELS_LATENTS,
        height=args.height,
        width=args.width,
        dtype=torch.float32,
        device=device,
        generator=torch.Generator(device=device).manual_seed(args.seed),
    )
    # Token count after 2x2 patchify; drives the flow-matching shift `mu`.
    image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
    mu = calculate_shift(
        image_seq_len,
        scheduler.config.get("base_image_seq_len", 256),
        scheduler.config.get("max_image_seq_len", 4096),
        scheduler.config.get("base_shift", 0.5),
        scheduler.config.get("max_shift", 1.15),
    )
    timesteps = retrieve_timesteps(scheduler, args.steps, device=device, mu=mu)
    onnx_path = Path(args.transformer_onnx) if args.transformer_onnx else None
    transformer_runner = AxSplitTransformer(
        Path(args.transformer_config),
        onnx_path,
        Path(args.transformer_subgraph_dir),
    )
    # Prefer the `sample` output of the auto_* subgraphs; grabbing an
    # intermediate feature by mistake would cause shape mismatches downstream.
    available_outputs = [name for spec in transformer_runner.specs for name in getattr(spec, "end", [])]
    preferred_output = "sample" if "sample" in available_outputs else transformer_runner.default_final_output
    final_output_name = args.final_output_name or preferred_output
    if final_output_name not in available_outputs:
        raise ValueError(f"指定的输出 {final_output_name} 不存在,可选: {available_outputs}")
    prompt_embeds_tensor = _stack_prompt_embeddings(prompt_embeds)
    iterator = timesteps if args.no_progress else tqdm(timesteps, desc="AX Denoising", dynamic_ncols=True)
    for t in iterator:
        timestep = t.expand(latents.shape[0])
        # Remap scheduler timestep to the model's input convention
        # (presumably what the exported graph expects — confirm against the
        # export script if the outputs ever look wrong).
        timestep_model_input = (1000 - timestep) / 1000
        latent_model_input = latents.to(torch.float32)
        # unsqueeze(2) inserts a singleton frame axis: the exported transformer
        # apparently takes video-style 5D input (B, C, 1, H, W).
        latent_np = latent_model_input.unsqueeze(2).to(dtype=torch.float32).cpu().numpy()
        prompt_np = prompt_embeds_tensor.to(dtype=torch.float32).cpu().numpy()
        timestep_np = timestep_model_input.to(dtype=torch.float32).cpu().numpy()
        model_out = transformer_runner(latent_np, prompt_np, timestep_np, final_output_name)
        # Drop the singleton frame axis again if the graph returned 5D output.
        if model_out.ndim == 5 and model_out.shape[2] == 1:
            model_out = np.squeeze(model_out, axis=2)
        model_out_tensor = torch.from_numpy(model_out).to(device=device, dtype=torch.float32)
        if model_out_tensor.dim() == 5 and model_out_tensor.size(2) == 1:
            model_out_tensor = model_out_tensor.squeeze(2)
        # Model predicts the negated velocity/noise; flip sign for the scheduler.
        noise_pred = -model_out_tensor
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    # Release the sessions cached by the transformer runner.
    transformer_runner.close()
    # Undo the VAE latent normalization before decoding.
    latents = (latents / VAE_SCALING_FACTOR) + VAE_SHIFT_FACTOR
    decoder_input = latents.to(dtype=torch.float32).cpu().numpy()
    if args.save_decoder_input:
        save_dir_path = Path(args.save_dir)
        save_dir_path.mkdir(parents=True, exist_ok=True)
        np.save(save_dir_path / "decoder_input.npy", decoder_input)
        logger.info("已保存 decoder 输入为 npy")
    # Drop the runner before opening the VAE session to keep peak memory down.
    del transformer_runner
    vae_decoder_session = AxInferenceSession(Path(args.vae_axmodel).as_posix())
    if decoder_input.ndim == 5 and decoder_input.shape[2] == 1:
        decoder_input = np.squeeze(decoder_input, axis=2)
    image = vae_decoder_session.run(None, {"latents": decoder_input})[0]
    del vae_decoder_session
    image = torch.from_numpy(image).to(device=device, dtype=torch.float32)
    image = image_processor.postprocess(image, output_type="pil")
    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    # prompt_idx in the filename links the image back to its prompt.
    target_path = save_dir / f"z_image_axmodel_{prompt_idx}.png"
    image[0].save(target_path)
    logger.info(f"AXModel 推理完成,结果保存到 {target_path}")


if __name__ == "__main__":
    """
    # 512x512 生成示例命令:
    python3 examples/z_image_fun/launcher_axmodel.py \
        --transformer-config pulsar2_configs/transformers_subgraph.json \
        --transformer-subgraph-dir comliled_subgraph_from_all_onnx \
        --vae-axmodel vae_decoder.axmodel

    # 1728x992 生成示例命令:
    python3 examples/z_image_fun/launcher_axmodel.py \
        --transformer-config pulsar2_configs/transformers_subgraph_1728x992.json \
        --transformer-subgraph-dir transformers_body_only_1728_992_split_onnx \
        --vae-axmodel onnx-models-1728x992/vae_decoder_simp_slim.axmodel \
        --max-seq-len 256 \
        --height 1728 --width 992
    """
    main()