yongqiang
launcher_axmodel 移除对编译上下文的依赖
83c76bb
#!/usr/bin/env python3
"""使用 AXModel 推理链路 (transformer + VAE decoder)。"""
from __future__ import annotations
import argparse
import json
import os
import random
import sys
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Tuple, Union
current_file_path = os.path.abspath(__file__)
# Make the script directory and its two ancestors importable (presumably so
# that local packages such as `videox_fun` and `scripts` resolve regardless of
# the current working directory).
project_roots = [
    os.path.dirname(current_file_path),
    os.path.dirname(os.path.dirname(current_file_path)),
    os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))),
]
for project_root in project_roots:
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
SCRIPT_DIR = Path(__file__).resolve().parent
# Root directory that contains the `VideoX-Fun` checkout (see the paths below).
REPO_ROOT = SCRIPT_DIR.parents[2]
if REPO_ROOT.as_posix() not in sys.path:
    sys.path.insert(0, REPO_ROOT.as_posix())
import numpy as np
import torch
from axengine import InferenceSession as AxInferenceSession
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils.torch_utils import randn_tensor
from omegaconf import OmegaConf
from PIL import Image
from loguru import logger
from tqdm import tqdm
from videox_fun.models import AutoTokenizer, Qwen3ForCausalLM
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from videox_fun.utils.utils import get_image_latent
# -----------------------------------------------------------------------------
# Model and resource paths
# -----------------------------------------------------------------------------
MODEL_NAME = "models/Diffusion_Transformer/Z-Image-Turbo/"
# Every config / ONNX / AXModel artefact below lives under <REPO_ROOT>/VideoX-Fun.
_VIDEOX_ROOT = REPO_ROOT / "VideoX-Fun"
CONFIG_PATH = _VIDEOX_ROOT / "config" / "z_image" / "z_image.yaml"
TRANSFORMER_CONFIG_PATH = _VIDEOX_ROOT / "pulsar2_configs" / "transformers_subgraph.json"
TRANSFORMER_ONNX_PATH = (
    _VIDEOX_ROOT / "compiled_subgraph_from_onnx" / "frontend" / "optimized_quant_axmodel.onnx"
)
# "comliled" (sic) matches the directory name used in the usage example at the
# bottom of this file. NOTE(review): `compiled_slice_quant_onnx` was noted as
# another candidate directory — confirm which one is current.
TRANSFORMER_AXMODEL_DIR = _VIDEOX_ROOT / "comliled_subgraph_from_all_onnx"
VAE_DECODER_AXMODEL = _VIDEOX_ROOT / "vae_decoder.axmodel"
# Default value of --save-dir.
SAVE_DIR = _VIDEOX_ROOT / "samples" / "z-image-t2i-axmodel"
# -----------------------------------------------------------------------------
# Run configuration
# -----------------------------------------------------------------------------
# Built-in prompt pool; one entry is used when --prompt is not given.
DEFAULT_PROMPTS = [
    "(masterpiece, best quality) solo female on a tropical beach, golden hour rim light, cinematic grading",
    "nighttime cyberpunk boulevard, neon reflections on wet asphalt, volumetric fog, wide shot",
    "sunrise over alpine mountains, low clouds in valleys, god rays, ultra-detailed landscape",
    "modern minimal living room, soft natural light, Scandinavian design, high-resolution interior render",
    "classical oil painting of a renaissance noblewoman, chiaroscuro lighting, rich textures",
    "macro photography of a dewdrop on a leaf, extreme detail, shallow depth of field",
    "futuristic sports car parked under neon lights, glossy paint, cinematic 35mm look",
    "ancient library with towering bookshelves, warm candlelight, dust motes in air",
    "portrait of an astronaut in full suit, visor reflection showing earth, studio lighting",
    "stormy sea with a lone lighthouse, crashing waves, dramatic clouds, long exposure feel",
    "cybernetic samurai standing in rain, backlit silhouette, moody blue-orange palette",
    "lush rainforest waterfall, soft mist, saturated greens, wide-angle composition",
    "product shot of a smartwatch on marble, softbox lighting, crisp shadows, advertisement style",
    "architectural exterior of a glass skyscraper at dusk, warm interior lights, reflections",
    "vintage film photograph of a 1950s diner at night, grain and halation, neon signage",
    "hyperrealistic bowl of ramen, steam rising, glossy broth, detailed toppings",
    "fantasy castle on a floating island, waterfalls falling into clouds, sunset lighting",
    "high-fashion editorial portrait, dramatic chiaroscuro, sharp focus on eyes",
    "aerial view of winding river through autumn forest, golden and crimson leaves",
    "studio shot of running shoes mid-air, motion blur trails, vibrant background gradient",
    "noir city alley in the 1940s, hard shadows, rain-slick pavement, moody atmosphere",
    "desert caravan at twilight, silhouettes of camels, soft purple sky, cinematic scope",
    "close-up of a mechanical watch movement, intricate gears, metallic reflections",
    "bioluminescent underwater reef, glowing corals, schools of fish, deep blue tones",
    "portrait of an elderly man with weathered face, soft window light, fine skin detail",
    "snowy village at night, warm cabin lights, smoke from chimneys, peaceful mood",
    "futuristic data center aisle, cool cyan lighting, depth and symmetry",
    "oil painting of a bowl of fruit in Dutch masters style, rich textures, dramatic lighting",
    "sunlit meadow with wildflowers, shallow depth of field, pastel color palette",
    "sci-fi corridor with volumetric light shafts, pristine white surfaces, wide lens",
    "luxury wristwatch on black velvet, high contrast, advertisement macro shot",
    "medieval marketplace at dawn, merchants setting up, soft warm light, lively details",
    "((masterpiece,best quality))1 young beautiful girl,ultra detailed,official art,unity 8k wallpaper,masterpiece, best quality, official art, extremely detailed CG unity 8k wallpaper, highly detailed, 1 girl, aqua eyes, light smile, ((grey hair)), hair flower, bracelet, choker, ribbon, JK, look at viewer, on the beach, in summer,"
]
# NOTE(review): the default prompt is drawn at import time from the unseeded
# global `random` module, so --seed does not control which prompt is chosen.
# `prompt_idx` is also reused in the output file name.
prompt_idx = random.randint(0, len(DEFAULT_PROMPTS) - 1)
PROMPT = DEFAULT_PROMPTS[prompt_idx]
# Default value of --negative-prompt.
NEG_PROMPT = " "
# NOTE(review): not referenced by the code visible here; prompts are encoded
# with do_classifier_free_guidance=False.
GUIDANCE_SCALE = 0.0
SEED = 42
HEIGHT, WIDTH = 512, 512
NUM_INFERENCE_STEPS = 9
NUM_CHANNELS_LATENTS = 16
# Pixel-to-latent downsampling; prepare_latents and the image processor use 2x this.
VAE_SCALE_FACTOR = 8
PATCH_SIZE = 2
FPATCH_SIZE = 1
# Default token length that prompts are padded/truncated to (--max-seq-len).
MAX_SEQ_LEN = 128
# Undo latent normalisation before the VAE decoder:
# latents / VAE_SCALING_FACTOR + VAE_SHIFT_FACTOR.
VAE_SCALING_FACTOR = 0.3611
VAE_SHIFT_FACTOR = 0.1159
# --sampler choice -> scheduler class.
SAMPLER_MAP = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
SAMPLER_NAME = "Flow"
# Default final output tensor name. When None, AxSplitTransformer falls back to
# the first output of the last auto_* subgraph, or of the last cfg subgraph if
# no auto_* subgraphs exist.
DEFAULT_FINAL_OUTPUT = None
# -----------------------------------------------------------------------------
# 工具函数 (复制自原 launcher 并微调)
# -----------------------------------------------------------------------------
def _infer_module_device(module: torch.nn.Module) -> torch.device:
param = next(module.parameters(), None)
if param is not None:
return param.device
buffer = next(module.buffers(), None)
if buffer is not None:
return buffer.device
return torch.device("cpu")
@contextmanager
def module_to_device(module: torch.nn.Module, target_device: torch.device):
    """Temporarily move ``module`` to ``target_device`` for a ``with`` body.

    The module is moved back to its original device on exit, also when the
    body raises. If the module had to be moved onto a CUDA device, that
    device's CUDA caching allocator is emptied afterwards.

    Args:
        module: Module to move. ``None`` is accepted and yielded unchanged.
        target_device: Destination as a ``torch.device`` or a device string
            such as ``"cuda:0"``. A falsy value (e.g. ``None``) keeps the
            module where it is.

    Yields:
        The same ``module`` object.
    """
    if module is None:
        yield module
        return
    original_device = _infer_module_device(module)
    # Normalise strings so that `.type` / `.index` below are always available.
    target_device = torch.device(target_device) if target_device else original_device
    needs_move = original_device != target_device
    moved_to_cuda = needs_move and target_device.type == "cuda"
    if needs_move:
        module.to(target_device)
    try:
        yield module
    finally:
        if needs_move:
            module.to(original_device)
            if moved_to_cuda and torch.cuda.is_available():
                # `index or current_device()` would treat cuda:0 as "no index"
                # and could empty another device's cache; test for None instead.
                cache_device = (
                    target_device.index
                    if target_device.index is not None
                    else torch.cuda.current_device()
                )
                with torch.cuda.device(cache_device):
                    torch.cuda.empty_cache()
def _encode_prompt(
    tokenizer: AutoTokenizer,
    text_encoder: Qwen3ForCausalLM,
    prompt: Union[str, List[str]],
    device: torch.device,
    prompt_embeds: Optional[List[torch.FloatTensor]] = None,
    max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
    """Encode one or more prompts into per-prompt text-encoder embeddings.

    Each prompt is wrapped in a single-turn user chat template, tokenized with
    ``max_length`` padding and truncation, and run through ``text_encoder``
    (temporarily placed on ``device``). The second-to-last hidden state is used.

    Args:
        tokenizer: Tokenizer providing ``apply_chat_template`` and ``__call__``.
        text_encoder: Causal LM whose hidden states serve as embeddings.
        prompt: A single prompt or a list of prompts.
        device: Device for the token tensors and the encoder forward pass.
        prompt_embeds: Precomputed embeddings; when given they are returned as-is.
        max_sequence_length: Length every prompt is padded/truncated to.

    Returns:
        One tensor per prompt, sliced from ``hidden_states[-2]``. The attention
        mask is not applied, so padding positions are kept and every entry has
        the same sequence length (presumably required by the fixed-shape AX
        transformer — confirm).
    """
    if prompt_embeds is not None:
        return prompt_embeds

    raw_prompts = [prompt] if isinstance(prompt, str) else list(prompt)
    templated = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
        for text in raw_prompts
    ]
    encoded = tokenizer(
        templated,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device).bool()

    with module_to_device(text_encoder, device):
        outputs = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        )
        hidden = outputs.hidden_states[-2]
    # Split the batch into a list of per-prompt tensors.
    return list(hidden)
def encode_prompt(
    tokenizer: AutoTokenizer,
    text_encoder: Qwen3ForCausalLM,
    prompt: Union[str, List[str]],
    device: torch.device,
    do_classifier_free_guidance: bool,
    negative_prompt: Optional[Union[str, List[str]]],
    max_sequence_length: int,
) -> Tuple[List[torch.FloatTensor], List[torch.FloatTensor]]:
    """Encode the positive prompt and, when CFG is enabled, the negative prompt.

    Args:
        tokenizer: Tokenizer forwarded to ``_encode_prompt``.
        text_encoder: Text encoder forwarded to ``_encode_prompt``.
        prompt: Positive prompt(s).
        device: Device used for encoding.
        do_classifier_free_guidance: Whether to also encode ``negative_prompt``.
        negative_prompt: Negative prompt(s); a falsy value is treated as ``""``.
        max_sequence_length: Padding/truncation length for both encodings.

    Returns:
        ``(prompt_embeds, negative_embeds)``. ``negative_embeds`` is an empty
        list when classifier-free guidance is disabled.
    """
    positive = _encode_prompt(
        tokenizer, text_encoder, prompt, device, None, max_sequence_length
    )
    if not do_classifier_free_guidance:
        return positive, []

    negative_source = negative_prompt or ""
    if isinstance(negative_source, str):
        negatives = [negative_source]
    else:
        negatives = list(negative_source)
    negative = _encode_prompt(
        tokenizer, text_encoder, negatives, device, None, max_sequence_length
    )
    return positive, negative
def _stack_prompt_embeddings(prompt_embeds_input):
if isinstance(prompt_embeds_input, list):
return torch.stack(prompt_embeds_input, dim=0)
return prompt_embeds_input
def prepare_latents(
    batch_size: int,
    num_channels_latents: int,
    height: int,
    width: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: torch.Generator,
) -> torch.FloatTensor:
    """Sample the initial Gaussian latents for the denoising loop.

    Pixel sizes are floored to a multiple of ``2 * VAE_SCALE_FACTOR`` and then
    expressed in latent units, i.e. ``2 * (size // (2 * VAE_SCALE_FACTOR))``.

    Returns:
        A tensor of shape ``(batch_size, num_channels_latents, latent_h, latent_w)``.
    """
    step = VAE_SCALE_FACTOR * 2
    latent_h = (int(height) // step) * 2
    latent_w = (int(width) // step) * 2
    return randn_tensor(
        (batch_size, num_channels_latents, latent_h, latent_w),
        generator=generator,
        device=device,
        dtype=dtype,
    )
def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    """Map an image token count linearly to the scheduler shift ``mu``.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``; values outside that range are extrapolated.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
def retrieve_timesteps(
    scheduler,
    num_inference_steps: int,
    device: torch.device,
    **kwargs,
):
    """Configure ``scheduler`` for ``num_inference_steps`` and return its timesteps.

    Extra keyword arguments (for example ``mu``) are forwarded unchanged to
    ``scheduler.set_timesteps``.
    """
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule
# -----------------------------------------------------------------------------
# Command-line arguments
# -----------------------------------------------------------------------------
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse and validate the launcher's command-line options.

    Args:
        argv: Argument list to parse; ``None`` (the default) reads ``sys.argv[1:]``.

    Returns:
        The parsed namespace.

    Raises:
        SystemExit: via ``parser.error`` when ``--height``/``--width`` are not
            positive multiples of ``2 * VAE_SCALE_FACTOR`` or ``--steps`` < 1.
    """
    parser = argparse.ArgumentParser(description="AXModel 推理 (transformer + VAE)")
    parser.add_argument("--prompt", type=str, default=None, help="正向提示词,不填则使用预置随机样本")
    parser.add_argument("--negative-prompt", type=str, default=NEG_PROMPT, help="反向提示词")
    parser.add_argument("--steps", type=int, default=NUM_INFERENCE_STEPS, help="迭代步数")
    parser.add_argument("--height", type=int, default=HEIGHT, help="生成高度,需被 16 整除")
    parser.add_argument("--width", type=int, default=WIDTH, help="生成宽度,需被 16 整除")
    parser.add_argument("--seed", type=int, default=SEED, help="随机种子")
    parser.add_argument("--sampler", type=str, choices=list(SAMPLER_MAP.keys()), default=SAMPLER_NAME, help="采样器")
    parser.add_argument("--max-seq-len", type=int, default=MAX_SEQ_LEN, help="最大文本长度")
    parser.add_argument("--save-dir", type=str, default=str(SAVE_DIR), help="结果输出目录")
    parser.add_argument("--transformer-config", type=str, required=True, help="子图配置 json")
    parser.add_argument("--transformer-onnx", type=str, default=None, help="原始 transformer onnx(可选,sub_configs 已覆盖可不填)")
    parser.add_argument("--transformer-subgraph-dir", type=str, required=True, help="子图 axmodel 目录")
    parser.add_argument("--vae-axmodel", type=str, required=True, help="VAE decoder axmodel 路径")
    parser.add_argument("--final-output-name", type=str, default=None, help="指定最终输出 tensor 名称,默认自动推断")
    parser.add_argument("--save-decoder-input", action="store_true", help="是否保存 decoder 输入 npy")
    parser.add_argument("--no-progress", action="store_true", help="关闭进度条输出")
    args = parser.parse_args(argv)

    # prepare_latents floors the pixel size to a multiple of 2 * VAE_SCALE_FACTOR.
    # Without this check, other values silently produce an image of a different size.
    size_multiple = VAE_SCALE_FACTOR * 2
    for flag, value in (("--height", args.height), ("--width", args.width)):
        if value <= 0 or value % size_multiple:
            parser.error(f"{flag} 需为 {size_multiple} 的正整数倍,当前为 {value}")
    if args.steps < 1:
        parser.error(f"--steps 需 >= 1,当前为 {args.steps}")
    return args
# -----------------------------------------------------------------------------
# AX transformer 子图执行器
# -----------------------------------------------------------------------------
from scripts.split_quant_onnx_by_subconfigs import SubGraphSpec, sanitize
class AxSplitTransformer:
    """Run a transformer that was split into several AXModel subgraphs.

    Subgraph specs come from two places:

    * ``cfg_XX`` specs built from ``compiler.sub_configs`` in the JSON config
      (start/end tensor names); their ``.axmodel`` file name is derived in
      ``_expected_path``.
    * ``auto_*.axmodel`` files found in ``model_dir``; their IO names are read
      from the model itself.

    Calling the instance runs, pass after pass, every not-yet-executed
    subgraph whose inputs are available, until the requested tensor exists.
    """

    def __init__(self, config_path: Path, onnx_path: Optional[Path], model_dir: Path):
        """Load subgraph specs from ``config_path`` and ``model_dir``.

        Args:
            config_path: JSON config containing ``compiler.sub_configs``.
            onnx_path: Original transformer ONNX. Stored only; not used here.
            model_dir: Directory holding the compiled subgraph ``.axmodel`` files.

        Raises:
            ValueError: if ``sub_configs`` is missing, empty or incomplete.
        """
        self.config_path = config_path
        self.onnx_path = onnx_path
        self.model_dir = model_dir
        # spec label -> open session. auto_* sessions are opened while probing
        # their IO; cfg_* sessions are opened lazily on first use.
        self._session_cache: Dict[str, AxInferenceSession] = {}
        config_specs = self._load_specs()
        auto_specs = self._load_auto_specs()
        # Execution scans specs in this order (config subgraphs first).
        self.specs = config_specs + auto_specs
        # The model output comes from the last subgraph: the last auto_* one if
        # any exist, otherwise the last config one.
        last_group = auto_specs if auto_specs else config_specs
        self.final_outputs = list(last_group[-1].end)
        self.default_final_output = DEFAULT_FINAL_OUTPUT or self.final_outputs[0]

    def _get_session(self, spec: SubGraphSpec) -> AxInferenceSession:
        """Return the cached session for ``spec``, opening it on first use."""
        if spec.label not in self._session_cache:
            path = self._expected_path(spec)
            self._session_cache[spec.label] = AxInferenceSession(path.as_posix())
            logger.info(f"加载子图 session: {spec.label} from {path.name}")
        return self._session_cache[spec.label]

    def close(self) -> None:
        """Drop all cached sessions so the runtime can release them.

        Sessions are reopened lazily if the instance is called again.
        """
        self._session_cache.clear()

    def _load_specs(self) -> List[SubGraphSpec]:
        """Build ``cfg_XX`` specs from ``compiler.sub_configs`` in the JSON config.

        Empty tensor names are ignored.

        Raises:
            ValueError: if there are no sub_configs, or an entry lacks start/end names.
        """
        with self.config_path.open("r", encoding="utf-8") as f:
            config = json.load(f)
        sub_configs = config.get("compiler", {}).get("sub_configs", [])
        if not sub_configs:
            raise ValueError("配置文件缺少 compiler.sub_configs")
        specs: List[SubGraphSpec] = []
        for idx, entry in enumerate(sub_configs):
            start = [name for name in entry.get("start_tensor_names", []) if name]
            end = [name for name in entry.get("end_tensor_names", []) if name]
            if not start or not end:
                raise ValueError(f"sub_config[{idx}] 定义不完整")
            specs.append(
                SubGraphSpec(
                    label=f"cfg_{idx:02d}",
                    start=start,
                    end=end,
                    node_names=set(),
                    source="config",
                )
            )
        return specs

    def _load_auto_specs(self) -> List[SubGraphSpec]:
        """Discover ``auto_*.axmodel`` subgraphs in ``model_dir`` (sorted by name).

        Each file is opened once to read its named inputs and outputs. Valid
        sessions are cached under the file stem, which is also the spec label,
        so they are not reopened. Files that fail to load or expose no named
        IO are skipped with a warning.
        """
        specs: List[SubGraphSpec] = []
        for path in sorted(self.model_dir.glob("auto_*.axmodel")):
            try:
                session = AxInferenceSession(path.as_posix())
                inputs = [info.name for info in session.get_inputs() if getattr(info, "name", None)]
                outputs = [info.name for info in session.get_outputs() if getattr(info, "name", None)]
            except Exception as exc:  # pragma: no cover - defensive
                # Best effort: a single broken file must not abort discovery.
                logger.warning(f"跳过 {path.name},加载/解析 IO 失败: {exc}")
                continue
            if not inputs or not outputs:
                # Not cached: this subgraph will never be scheduled.
                logger.warning(f"跳过 {path.name},未找到有效的输入/输出定义")
                continue
            # Cache the session to avoid opening the file a second time.
            self._session_cache[path.stem] = session
            specs.append(
                SubGraphSpec(
                    label=path.stem,
                    start=inputs,
                    end=outputs,
                    node_names=set(),
                    source="auto",
                    output_path=path,
                )
            )
        return specs

    def _expected_path(self, spec: SubGraphSpec) -> Path:
        """Return the ``.axmodel`` path for ``spec``.

        Uses ``spec.output_path`` when set. Otherwise the name is
        ``<label>_<first start>_to_<first end>_<source>.axmodel`` inside
        ``model_dir``, with tensor names passed through ``sanitize``.

        Raises:
            FileNotFoundError: if the file does not exist.
        """
        if spec.output_path is not None:
            path = spec.output_path
        else:
            head = sanitize(spec.start[0]) if spec.start else "const"
            tail = sanitize(spec.end[0]) if spec.end else "out"
            filename = f"{spec.label}_{head}_to_{tail}_{spec.source}.axmodel"
            path = self.model_dir / filename
        if not path.exists():
            raise FileNotFoundError(f"缺少 AXModel: {path}")
        return path

    def __call__(
        self,
        latent_np: np.ndarray,
        prompt_np: np.ndarray,
        timestep_np: np.ndarray,
        final_output_name: Optional[str] = None,
    ) -> np.ndarray:
        """Run the subgraph chain once and return the requested output.

        Args:
            latent_np: Fed as tensor ``latent_model_input``.
            prompt_np: Fed as tensor ``prompt_embeds``.
            timestep_np: Fed as tensor ``timestep``.
            final_output_name: Tensor to return; defaults to ``default_final_output``.

        Returns:
            The array produced for the requested tensor.

        Raises:
            RuntimeError: if no remaining subgraph can run (missing inputs)
                before the requested tensor has been produced.
        """
        tensor_store: Dict[str, np.ndarray] = {
            "latent_model_input": latent_np,
            "prompt_embeds": prompt_np,
            "timestep": timestep_np,
        }
        executed = set()
        target = final_output_name or self.default_final_output
        # Ready-driven scheduling: each pass runs every not-yet-executed
        # subgraph whose inputs are all available. Sessions stay cached across
        # calls; release them explicitly with close().
        while target not in tensor_store:
            progressed = False
            for spec in self.specs:
                if spec.label in executed:
                    continue
                if not all(name in tensor_store for name in spec.start):
                    continue
                session = self._get_session(spec)
                feeds = {name: tensor_store[name] for name in spec.start}
                outputs = session.run(spec.end, feeds)
                tensor_store.update(zip(spec.end, outputs))
                executed.add(spec.label)
                progressed = True
                # Stop once the target exists: remaining ready subgraphs would
                # only compute tensors that are never returned.
                if target in tensor_store:
                    break
            if not progressed:
                missing = [
                    (spec.label, [name for name in spec.start if name not in tensor_store])
                    for spec in self.specs
                    if spec.label not in executed
                ]
                raise RuntimeError(
                    f"子图调度中断,缺少输入: {missing}; 当前可用: {list(tensor_store.keys())}"
                )
        return tensor_store[target]
# -----------------------------------------------------------------------------
# Main pipeline
# -----------------------------------------------------------------------------
def _resolve_final_output_name(
    transformer_runner: AxSplitTransformer, requested: Optional[str]
) -> str:
    """Choose the transformer output tensor to read at every denoising step.

    An explicit ``requested`` name wins. Otherwise a tensor named ``sample`` is
    preferred when any subgraph produces it, so an intermediate feature with a
    mismatching shape is not picked; failing that the runner's default is used.

    Raises:
        ValueError: if the chosen name is not an output of any subgraph.
    """
    available_outputs = [
        name for spec in transformer_runner.specs for name in getattr(spec, "end", [])
    ]
    preferred_output = (
        "sample" if "sample" in available_outputs else transformer_runner.default_final_output
    )
    final_output_name = requested or preferred_output
    if final_output_name not in available_outputs:
        raise ValueError(f"指定的输出 {final_output_name} 不存在,可选: {available_outputs}")
    return final_output_name


def _denoise(
    latents: torch.Tensor,
    timesteps,
    scheduler,
    transformer_runner: AxSplitTransformer,
    prompt_embeds_tensor: torch.Tensor,
    final_output_name: str,
    device: torch.device,
    show_progress: bool,
) -> torch.Tensor:
    """Run the scheduler loop with the AX transformer as the denoiser.

    Returns:
        The final latents returned by the last ``scheduler.step``.
    """
    # The prompt embeddings do not change between steps; convert them once.
    prompt_np = prompt_embeds_tensor.to(dtype=torch.float32).cpu().numpy()
    iterator = (
        tqdm(timesteps, desc="AX Denoising", dynamic_ncols=True) if show_progress else timesteps
    )
    for t in iterator:
        timestep = t.expand(latents.shape[0])
        # The transformer is fed (1000 - t) / 1000 rather than the raw timestep.
        timestep_model_input = (1000 - timestep) / 1000
        # Insert a singleton axis at dim 2 (presumably the frame axis the
        # exported transformer expects — confirm).
        latent_np = latents.to(torch.float32).unsqueeze(2).cpu().numpy()
        timestep_np = timestep_model_input.to(dtype=torch.float32).cpu().numpy()
        model_out = transformer_runner(latent_np, prompt_np, timestep_np, final_output_name)
        # Drop that singleton axis again if the model returns it.
        if model_out.ndim == 5 and model_out.shape[2] == 1:
            model_out = np.squeeze(model_out, axis=2)
        model_out_tensor = torch.from_numpy(model_out).to(device=device, dtype=torch.float32)
        # The model output is negated before being passed to the scheduler.
        noise_pred = -model_out_tensor
        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
    return latents


def _decode_latents(
    latents: torch.Tensor,
    vae_axmodel_path: str,
    image_processor: VaeImageProcessor,
    device: torch.device,
    save_dir: Path,
    save_decoder_input: bool,
) -> List[Image.Image]:
    """Un-normalise ``latents`` and decode them to PIL images with the AX VAE.

    When ``save_decoder_input`` is set, the decoder input is also written to
    ``save_dir / "decoder_input.npy"``.
    """
    latents = (latents / VAE_SCALING_FACTOR) + VAE_SHIFT_FACTOR
    decoder_input = latents.to(dtype=torch.float32).cpu().numpy()
    if save_decoder_input:
        save_dir.mkdir(parents=True, exist_ok=True)
        np.save(save_dir / "decoder_input.npy", decoder_input)
        logger.info("已保存 decoder 输入为 npy")
    if decoder_input.ndim == 5 and decoder_input.shape[2] == 1:
        decoder_input = np.squeeze(decoder_input, axis=2)
    vae_decoder_session = AxInferenceSession(Path(vae_axmodel_path).as_posix())
    image = vae_decoder_session.run(None, {"latents": decoder_input})[0]
    del vae_decoder_session
    image = torch.from_numpy(image).to(device=device, dtype=torch.float32)
    return image_processor.postprocess(image, output_type="pil")


def main() -> None:
    """Run prompt encoding, AX denoising and AX VAE decoding, then save a PNG."""
    args = parse_args()
    prompt_text = args.prompt if args.prompt is not None else PROMPT
    logger.info(f"使用的 prompt: {prompt_text}")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.set_grad_enabled(False)

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder="tokenizer")
    text_encoder = Qwen3ForCausalLM.from_pretrained(
        MODEL_NAME,
        subfolder="text_encoder",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    text_encoder.eval()
    scheduler = SAMPLER_MAP[args.sampler].from_pretrained(MODEL_NAME, subfolder="scheduler")
    image_processor = VaeImageProcessor(vae_scale_factor=VAE_SCALE_FACTOR * 2)

    prompt_embeds, _ = encode_prompt(
        tokenizer,
        text_encoder,
        prompt_text,
        device,
        do_classifier_free_guidance=False,
        negative_prompt=args.negative_prompt,
        max_sequence_length=args.max_seq_len,
    )
    # The text encoder is not used again; drop the reference so its weights
    # can be garbage-collected before AX inference.
    del text_encoder
    prompt_embeds_tensor = _stack_prompt_embeddings(prompt_embeds)

    latents = prepare_latents(
        batch_size=1,
        num_channels_latents=NUM_CHANNELS_LATENTS,
        height=args.height,
        width=args.width,
        dtype=torch.float32,
        device=device,
        generator=torch.Generator(device=device).manual_seed(args.seed),
    )
    image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
    mu = calculate_shift(
        image_seq_len,
        scheduler.config.get("base_image_seq_len", 256),
        scheduler.config.get("max_image_seq_len", 4096),
        scheduler.config.get("base_shift", 0.5),
        scheduler.config.get("max_shift", 1.15),
    )
    timesteps = retrieve_timesteps(scheduler, args.steps, device=device, mu=mu)

    onnx_path = Path(args.transformer_onnx) if args.transformer_onnx else None
    transformer_runner = AxSplitTransformer(
        Path(args.transformer_config),
        onnx_path,
        Path(args.transformer_subgraph_dir),
    )
    try:
        final_output_name = _resolve_final_output_name(transformer_runner, args.final_output_name)
        latents = _denoise(
            latents,
            timesteps,
            scheduler,
            transformer_runner,
            prompt_embeds_tensor,
            final_output_name,
            device,
            show_progress=not args.no_progress,
        )
    finally:
        # Release the cached subgraph sessions even if denoising fails.
        transformer_runner.close()
    del transformer_runner

    save_dir = Path(args.save_dir)
    image = _decode_latents(
        latents,
        args.vae_axmodel,
        image_processor,
        device,
        save_dir,
        args.save_decoder_input,
    )
    save_dir.mkdir(parents=True, exist_ok=True)
    # Built-in prompts are tagged with their pool index. A user-supplied prompt
    # has no index, so tag it "custom" rather than reuse the unrelated random one.
    name_tag = prompt_idx if args.prompt is None else "custom"
    target_path = save_dir / f"z_image_axmodel_{name_tag}.png"
    image[0].save(target_path)
    logger.info(f"AXModel 推理完成,结果保存到 {target_path}")
if __name__ == "__main__":
    # Example command for 512x512 generation:
    #   python3 examples/z_image_fun/launcher_axmodel.py \
    #       --transformer-config pulsar2_configs/transformers_subgraph.json \
    #       --transformer-subgraph-dir comliled_subgraph_from_all_onnx \
    #       --vae-axmodel vae_decoder.axmodel
    #
    # Example command for 1728x992 generation:
    #   python3 examples/z_image_fun/launcher_axmodel.py \
    #       --transformer-config pulsar2_configs/transformers_subgraph_1728x992.json \
    #       --transformer-subgraph-dir transformers_body_only_1728_992_split_onnx \
    #       --vae-axmodel onnx-models-1728x992/vae_decoder_simp_slim.axmodel \
    #       --max-seq-len 256 \
    #       --height 1728 --width 992
    main()