| | |
| | """使用 AXModel 推理链路 (transformer + VAE decoder)。""" |
| | from __future__ import annotations |
| |
|
| | import argparse |
| | import json |
| | import os |
| | import random |
| | import sys |
| | from contextlib import contextmanager |
| | from pathlib import Path |
| | from typing import Dict, Iterable, List, Optional, Tuple, Union |
| |
|
| | current_file_path = os.path.abspath(__file__) |
| | project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))] |
| | for project_root in project_roots: |
| | sys.path.insert(0, project_root) if project_root not in sys.path else None |
| |
|
| | SCRIPT_DIR = Path(__file__).resolve().parent |
| | REPO_ROOT = SCRIPT_DIR.parents[2] |
| | if REPO_ROOT.as_posix() not in sys.path: |
| | sys.path.insert(0, REPO_ROOT.as_posix()) |
| |
|
| | import numpy as np |
| | import torch |
| | from axengine import InferenceSession as AxInferenceSession |
| | from diffusers import FlowMatchEulerDiscreteScheduler |
| | from diffusers.image_processor import VaeImageProcessor |
| | from diffusers.utils.torch_utils import randn_tensor |
| | from omegaconf import OmegaConf |
| | from PIL import Image |
| | from loguru import logger |
| | from tqdm import tqdm |
| |
|
| | from videox_fun.models import AutoTokenizer, Qwen3ForCausalLM |
| | from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler |
| | from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler |
| | from videox_fun.utils.utils import get_image_latent |
| |
|
| |
|
| | |
| | |
| | |
# --- Model / artifact locations --------------------------------------------
MODEL_NAME = "models/Diffusion_Transformer/Z-Image-Turbo/"
CONFIG_PATH = REPO_ROOT / "VideoX-Fun" / "config" / "z_image" / "z_image.yaml"
TRANSFORMER_CONFIG_PATH = REPO_ROOT / "VideoX-Fun" / "pulsar2_configs" / "transformers_subgraph.json"
TRANSFORMER_ONNX_PATH = REPO_ROOT / "VideoX-Fun" / "compiled_subgraph_from_onnx" / "frontend" / "optimized_quant_axmodel.onnx"
# NOTE(review): "comliled" looks like a typo, but the example command at the
# bottom of this file uses the same spelling — the on-disk directory is
# presumably misspelled too, so do not "fix" this without renaming the dir.
TRANSFORMER_AXMODEL_DIR = REPO_ROOT / "VideoX-Fun" / "comliled_subgraph_from_all_onnx"
VAE_DECODER_AXMODEL = REPO_ROOT / "VideoX-Fun" / "vae_decoder.axmodel"
SAVE_DIR = REPO_ROOT / "VideoX-Fun" / "samples" / "z-image-t2i-axmodel"

# Built-in demo prompts; one is picked at random at import time and used when
# the caller does not pass --prompt.
DEFAULT_PROMPTS = [
    "(masterpiece, best quality) solo female on a tropical beach, golden hour rim light, cinematic grading",
    "nighttime cyberpunk boulevard, neon reflections on wet asphalt, volumetric fog, wide shot",
    "sunrise over alpine mountains, low clouds in valleys, god rays, ultra-detailed landscape",
    "modern minimal living room, soft natural light, Scandinavian design, high-resolution interior render",
    "classical oil painting of a renaissance noblewoman, chiaroscuro lighting, rich textures",
    "macro photography of a dewdrop on a leaf, extreme detail, shallow depth of field",
    "futuristic sports car parked under neon lights, glossy paint, cinematic 35mm look",
    "ancient library with towering bookshelves, warm candlelight, dust motes in air",
    "portrait of an astronaut in full suit, visor reflection showing earth, studio lighting",
    "stormy sea with a lone lighthouse, crashing waves, dramatic clouds, long exposure feel",
    "cybernetic samurai standing in rain, backlit silhouette, moody blue-orange palette",
    "lush rainforest waterfall, soft mist, saturated greens, wide-angle composition",
    "product shot of a smartwatch on marble, softbox lighting, crisp shadows, advertisement style",
    "architectural exterior of a glass skyscraper at dusk, warm interior lights, reflections",
    "vintage film photograph of a 1950s diner at night, grain and halation, neon signage",
    "hyperrealistic bowl of ramen, steam rising, glossy broth, detailed toppings",
    "fantasy castle on a floating island, waterfalls falling into clouds, sunset lighting",
    "high-fashion editorial portrait, dramatic chiaroscuro, sharp focus on eyes",
    "aerial view of winding river through autumn forest, golden and crimson leaves",
    "studio shot of running shoes mid-air, motion blur trails, vibrant background gradient",
    "noir city alley in the 1940s, hard shadows, rain-slick pavement, moody atmosphere",
    "desert caravan at twilight, silhouettes of camels, soft purple sky, cinematic scope",
    "close-up of a mechanical watch movement, intricate gears, metallic reflections",
    "bioluminescent underwater reef, glowing corals, schools of fish, deep blue tones",
    "portrait of an elderly man with weathered face, soft window light, fine skin detail",
    "snowy village at night, warm cabin lights, smoke from chimneys, peaceful mood",
    "futuristic data center aisle, cool cyan lighting, depth and symmetry",
    "oil painting of a bowl of fruit in Dutch masters style, rich textures, dramatic lighting",
    "sunlit meadow with wildflowers, shallow depth of field, pastel color palette",
    "sci-fi corridor with volumetric light shafts, pristine white surfaces, wide lens",
    "luxury wristwatch on black velvet, high contrast, advertisement macro shot",
    "medieval marketplace at dawn, merchants setting up, soft warm light, lively details",
    "((masterpiece,best quality))1 young beautiful girl,ultra detailed,official art,unity 8k wallpaper,masterpiece, best quality, official art, extremely detailed CG unity 8k wallpaper, highly detailed, 1 girl, aqua eyes, light smile, ((grey hair)), hair flower, bracelet, choker, ribbon, JK, look at viewer, on the beach, in summer,",
]
# Unseeded RNG: the chosen demo prompt (and the output filename suffix built
# from prompt_idx in main()) varies on every run.
prompt_idx = random.randint(0, len(DEFAULT_PROMPTS) - 1)
PROMPT = DEFAULT_PROMPTS[prompt_idx]
NEG_PROMPT = " "
GUIDANCE_SCALE = 0.0  # not referenced in this script (CFG is disabled in main) — TODO confirm it can be removed
SEED = 42
HEIGHT, WIDTH = 512, 512
NUM_INFERENCE_STEPS = 9
NUM_CHANNELS_LATENTS = 16  # latent channel count passed to prepare_latents()
VAE_SCALE_FACTOR = 8  # pixel-to-latent spatial downscale used by prepare_latents()
PATCH_SIZE = 2  # not referenced in this script — presumably kept for parity with other launchers
FPATCH_SIZE = 1  # not referenced in this script
MAX_SEQ_LEN = 128  # default --max-seq-len (tokenizer padding length)
VAE_SCALING_FACTOR = 0.3611  # latents are divided by this before VAE decode
VAE_SHIFT_FACTOR = 0.1159  # ... then shifted by this (see main())

# Scheduler registry for the --sampler CLI option.
SAMPLER_MAP = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
SAMPLER_NAME = "Flow"

# When None, AxSplitTransformer falls back to the first output of its last
# sub-graph as the final tensor name.
DEFAULT_FINAL_OUTPUT = None
| |
|
| | |
| | |
| | |
| |
|
| | def _infer_module_device(module: torch.nn.Module) -> torch.device: |
| | param = next(module.parameters(), None) |
| | if param is not None: |
| | return param.device |
| | buffer = next(module.buffers(), None) |
| | if buffer is not None: |
| | return buffer.device |
| | return torch.device("cpu") |
| |
|
| |
|
@contextmanager
def module_to_device(module: torch.nn.Module, target_device: torch.device):
    """Temporarily move *module* to *target_device*, restoring it on exit.

    Yields the module unchanged when it is None or already on the target
    device.  After moving a module back off a CUDA device, the CUDA cache for
    that device is emptied to release the freed memory.

    Args:
        module: Module to relocate (may be None).
        target_device: Destination device; falsy values mean "stay put".
    """
    if module is None:
        yield module
        return
    original_device = _infer_module_device(module)
    target_device = target_device or original_device
    needs_move = original_device != target_device
    moved_to_cuda = needs_move and target_device.type == "cuda"
    if needs_move:
        module.to(target_device)
    try:
        yield module
    finally:
        if needs_move:
            module.to(original_device)
        if moved_to_cuda and torch.cuda.is_available():
            # Bug fix: `target_device.index or current_device()` treated the
            # valid device index 0 as falsy and silently emptied the cache of
            # the *current* device instead. Use an explicit None check.
            cache_device = (
                target_device.index
                if target_device.index is not None
                else torch.cuda.current_device()
            )
            with torch.cuda.device(cache_device):
                torch.cuda.empty_cache()
| |
|
| |
|
def _encode_prompt(
    tokenizer: AutoTokenizer,
    text_encoder: Qwen3ForCausalLM,
    prompt: Union[str, List[str]],
    device: torch.device,
    prompt_embeds: Optional[List[torch.FloatTensor]] = None,
    max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
    """Encode prompt(s) with the Qwen3 text encoder.

    Each prompt is wrapped in a single-user-turn chat template, tokenized with
    fixed-length padding, and run through the encoder; the second-to-last
    hidden state is returned as one tensor per prompt.  If precomputed
    *prompt_embeds* are supplied they are returned untouched.
    """
    if prompt_embeds is not None:
        return prompt_embeds

    raw_prompts = [prompt] if isinstance(prompt, str) else list(prompt)
    chat_prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": text}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
        for text in raw_prompts
    ]

    encoded = tokenizer(
        chat_prompts,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )
    input_ids = encoded.input_ids.to(device)
    attention_mask = encoded.attention_mask.to(device).bool()

    # Move the (possibly CPU-resident) encoder onto `device` only for the
    # duration of the forward pass.
    with module_to_device(text_encoder, device):
        hidden = text_encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
        ).hidden_states[-2]
    # Split the batch into per-prompt tensors.
    return list(hidden)
| |
|
| |
|
def encode_prompt(
    tokenizer: AutoTokenizer,
    text_encoder: Qwen3ForCausalLM,
    prompt: Union[str, List[str]],
    device: torch.device,
    do_classifier_free_guidance: bool,
    negative_prompt: Optional[Union[str, List[str]]],
    max_sequence_length: int,
) -> Tuple[List[torch.FloatTensor], List[torch.FloatTensor]]:
    """Encode the positive prompt and, when CFG is enabled, the negative one.

    Returns a pair ``(positive_embeds, negative_embeds)``; the negative list
    is empty when classifier-free guidance is disabled.
    """
    positive_embeds = _encode_prompt(
        tokenizer, text_encoder, prompt, device, None, max_sequence_length
    )
    if not do_classifier_free_guidance:
        return positive_embeds, []

    # A falsy negative prompt (None / "") becomes one empty-string prompt.
    negative_source = negative_prompt or ""
    negative_prompts = (
        [negative_source] if isinstance(negative_source, str) else list(negative_source)
    )
    negative_embeds = _encode_prompt(
        tokenizer, text_encoder, negative_prompts, device, None, max_sequence_length
    )
    return positive_embeds, negative_embeds
| |
|
| |
|
| | def _stack_prompt_embeddings(prompt_embeds_input): |
| | if isinstance(prompt_embeds_input, list): |
| | return torch.stack(prompt_embeds_input, dim=0) |
| | return prompt_embeds_input |
| |
|
| |
|
def prepare_latents(
    batch_size: int,
    num_channels_latents: int,
    height: int,
    width: int,
    dtype: torch.dtype,
    device: torch.device,
    generator: torch.Generator,
) -> torch.FloatTensor:
    """Draw the initial Gaussian latents for the diffusion loop.

    The pixel dimensions are mapped to latent space by dividing by
    ``VAE_SCALE_FACTOR``, rounded down to an even number (the ``2 * (x // 2k)``
    form guarantees divisibility by the patch size of 2).
    """
    latent_height = 2 * (int(height) // (VAE_SCALE_FACTOR * 2))
    latent_width = 2 * (int(width) // (VAE_SCALE_FACTOR * 2))
    return randn_tensor(
        (batch_size, num_channels_latents, latent_height, latent_width),
        generator=generator,
        device=device,
        dtype=dtype,
    )
| |
|
| |
|
def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
) -> float:
    """Linearly interpolate the scheduler shift (mu) from the image sequence length.

    Maps ``base_seq_len -> base_shift`` and ``max_seq_len -> max_shift``;
    values outside the range are extrapolated along the same line.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
| |
|
| |
|
def retrieve_timesteps(
    scheduler,
    num_inference_steps: int,
    device: torch.device,
    **kwargs,
):
    """Configure *scheduler* for *num_inference_steps* and return its timesteps.

    Extra keyword arguments (e.g. ``mu``) are forwarded to
    ``scheduler.set_timesteps`` unchanged.
    """
    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    return scheduler.timesteps
| |
|
| |
|
| | |
| | |
| | |
| |
|
| |
|
def parse_args() -> argparse.Namespace:
    """Build and evaluate the CLI for the AXModel text-to-image launcher."""
    parser = argparse.ArgumentParser(description="AXModel 推理 (transformer + VAE)")
    add = parser.add_argument

    # Generation parameters.
    add("--prompt", type=str, default=None, help="正向提示词,不填则使用预置随机样本")
    add("--negative-prompt", type=str, default=NEG_PROMPT, help="反向提示词")
    add("--steps", type=int, default=NUM_INFERENCE_STEPS, help="迭代步数")
    add("--height", type=int, default=HEIGHT, help="生成高度,需被 16 整除")
    add("--width", type=int, default=WIDTH, help="生成宽度,需被 16 整除")
    add("--seed", type=int, default=SEED, help="随机种子")
    add("--sampler", type=str, choices=list(SAMPLER_MAP.keys()), default=SAMPLER_NAME, help="采样器")
    add("--max-seq-len", type=int, default=MAX_SEQ_LEN, help="最大文本长度")
    add("--save-dir", type=str, default=str(SAVE_DIR), help="结果输出目录")

    # Model artifacts.
    add("--transformer-config", type=str, required=True, help="子图配置 json")
    add("--transformer-onnx", type=str, default=None, help="原始 transformer onnx(可选,sub_configs 已覆盖可不填)")
    add("--transformer-subgraph-dir", type=str, required=True, help="子图 axmodel 目录")
    add("--vae-axmodel", type=str, required=True, help="VAE decoder axmodel 路径")

    # Output / debugging switches.
    add("--final-output-name", type=str, default=None, help="指定最终输出 tensor 名称,默认自动推断")
    add("--save-decoder-input", action="store_true", help="是否保存 decoder 输入 npy")
    add("--no-progress", action="store_true", help="关闭进度条输出")

    return parser.parse_args()
| |
|
| |
|
| | |
| | |
| | |
| |
|
| | from scripts.split_quant_onnx_by_subconfigs import SubGraphSpec, sanitize |
| |
|
| |
|
class AxSplitTransformer:
    """Runs the diffusion transformer as a pipeline of AXModel sub-graphs.

    Sub-graph boundaries come from two sources:

    * ``compiler.sub_configs`` entries in the pulsar2 JSON config, and
    * ``auto_*.axmodel`` files discovered in ``model_dir`` (their I/O tensor
      names are read from the compiled model itself).

    ``__call__`` schedules sub-graphs data-flow style: a sub-graph runs as
    soon as all of its input tensors are available, until the requested final
    output tensor has been produced.
    """

    def __init__(self, config_path: Path, onnx_path: Optional[Path], model_dir: Path):
        self.config_path = config_path
        # Kept only for bookkeeping; inference never touches the ONNX file.
        self.onnx_path = onnx_path
        self.model_dir = model_dir
        self._session_cache: Dict[str, AxInferenceSession] = {}
        config_specs = self._load_specs()
        auto_specs = self._load_auto_specs()
        self.specs = config_specs + auto_specs
        # The last spec of the last non-empty group defines the graph outputs.
        last_group = auto_specs if auto_specs else config_specs
        self.final_outputs = list(last_group[-1].end)
        self.default_final_output = DEFAULT_FINAL_OUTPUT or self.final_outputs[0]

    def _get_session(self, spec: SubGraphSpec) -> AxInferenceSession:
        """Return (lazily creating and caching) the session for *spec*."""
        if spec.label not in self._session_cache:
            path = self._expected_path(spec)
            self._session_cache[spec.label] = AxInferenceSession(path.as_posix())
            logger.info(f"加载子图 session: {spec.label} from {path.name}")
        return self._session_cache[spec.label]

    def close(self) -> None:
        """Drop every cached session so its native resources can be freed.

        Fix: the previous implementation looped with ``del sess``, which only
        unbound a local name; clearing the dict releases all held references.
        """
        self._session_cache.clear()

    def _load_specs(self) -> List[SubGraphSpec]:
        """Build sub-graph specs from ``compiler.sub_configs`` in the JSON config.

        Raises:
            ValueError: if the config has no sub_configs, or an entry lacks
                start/end tensor names.
        """
        with self.config_path.open("r", encoding="utf-8") as f:
            config = json.load(f)
        sub_configs = config.get("compiler", {}).get("sub_configs", [])
        if not sub_configs:
            raise ValueError("配置文件缺少 compiler.sub_configs")
        specs: List[SubGraphSpec] = []
        for idx, entry in enumerate(sub_configs):
            start = [name for name in entry.get("start_tensor_names", []) if name]
            end = [name for name in entry.get("end_tensor_names", []) if name]
            if not start or not end:
                raise ValueError(f"sub_config[{idx}] 定义不完整")
            specs.append(
                SubGraphSpec(
                    label=f"cfg_{idx:02d}",
                    start=start,
                    end=end,
                    node_names=set(),
                    source="config",
                )
            )
        return specs

    def _load_auto_specs(self) -> List[SubGraphSpec]:
        """Discover ``auto_*.axmodel`` files and derive specs from their I/O names.

        Files that fail to load or expose no named inputs/outputs are skipped
        with a warning.
        """
        specs: List[SubGraphSpec] = []
        for path in sorted(self.model_dir.glob("auto_*.axmodel")):
            try:
                session = AxInferenceSession(path.as_posix())
                inputs = [info.name for info in session.get_inputs() if getattr(info, "name", None)]
                outputs = [info.name for info in session.get_outputs() if getattr(info, "name", None)]
            except Exception as exc:
                logger.warning(f"跳过 {path.name},加载/解析 IO 失败: {exc}")
                continue
            if not inputs or not outputs:
                # Fix: previously the session was cached *before* this check,
                # keeping a live session alive for a model we then skipped.
                logger.warning(f"跳过 {path.name},未找到有效的输入/输出定义")
                continue
            # Only cache sessions for models we actually keep.
            self._session_cache[path.stem] = session
            specs.append(
                SubGraphSpec(
                    label=path.stem,
                    start=inputs,
                    end=outputs,
                    node_names=set(),
                    source="auto",
                    output_path=path,
                )
            )
        return specs

    def _expected_path(self, spec: SubGraphSpec) -> Path:
        """Resolve the .axmodel path for *spec*, raising if the file is missing."""
        if spec.output_path is not None:
            path = spec.output_path
        else:
            # Config-derived specs follow the naming scheme produced by the
            # split script: <label>_<first-start>_to_<first-end>_<source>.axmodel
            head = sanitize(spec.start[0]) if spec.start else "const"
            tail = sanitize(spec.end[0]) if spec.end else "out"
            path = self.model_dir / f"{spec.label}_{head}_to_{tail}_{spec.source}.axmodel"
        if not path.exists():
            raise FileNotFoundError(f"缺少 AXModel: {path}")
        return path

    def __call__(
        self,
        latent_np: np.ndarray,
        prompt_np: np.ndarray,
        timestep_np: np.ndarray,
        final_output_name: Optional[str] = None,
    ) -> np.ndarray:
        """Run the sub-graph pipeline until *final_output_name* is produced.

        Args:
            latent_np: Seed tensor for "latent_model_input".
            prompt_np: Seed tensor for "prompt_embeds".
            timestep_np: Seed tensor for "timestep".
            final_output_name: Target tensor; defaults to the inferred final output.

        Raises:
            RuntimeError: if no remaining sub-graph can run with the tensors
                available (unsatisfiable dependency graph).
        """
        tensor_store: Dict[str, np.ndarray] = {
            "latent_model_input": latent_np,
            "prompt_embeds": prompt_np,
            "timestep": timestep_np,
        }
        executed = set()
        target = final_output_name or self.default_final_output

        # Data-flow scheduling: each pass runs every spec whose inputs are
        # ready; stop once the target tensor exists.  A pass that makes no
        # progress means the graph cannot be satisfied.
        while target not in tensor_store:
            progressed = False
            for spec in self.specs:
                if spec.label in executed:
                    continue
                if not all(name in tensor_store for name in spec.start):
                    continue
                session = self._get_session(spec)
                inputs = {name: tensor_store[name] for name in spec.start}
                outputs = session.run(spec.end, inputs)
                tensor_store.update(zip(spec.end, outputs))
                executed.add(spec.label)
                progressed = True
            if not progressed:
                missing = [
                    (spec.label, [name for name in spec.start if name not in tensor_store])
                    for spec in self.specs
                    if spec.label not in executed
                ]
                raise RuntimeError(
                    f"子图调度中断,缺少输入: {missing}; 当前可用: {list(tensor_store.keys())}"
                )

        return tensor_store[target]
| |
|
| |
|
| | |
| | |
| | |
| |
|
def main() -> None:
    """End-to-end text-to-image run: encode prompt, denoise on AXModel sub-graphs,
    decode with the AXModel VAE decoder, and save a PNG."""
    args = parse_args()
    prompt_text = args.prompt if args.prompt is not None else PROMPT
    logger.info(f"使用的 prompt: {prompt_text}")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    torch.set_grad_enabled(False)  # pure inference; no autograd bookkeeping

    # Text encoder runs in PyTorch (bf16); only transformer + VAE use AXModel.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, subfolder="tokenizer")
    text_encoder = Qwen3ForCausalLM.from_pretrained(
        MODEL_NAME,
        subfolder="text_encoder",
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
    )
    text_encoder.eval()

    scheduler_cls = SAMPLER_MAP[args.sampler]
    scheduler = scheduler_cls.from_pretrained(MODEL_NAME, subfolder="scheduler")

    # Effective pixel/latent scale is 16 (VAE x8 times patch size 2).
    image_processor = VaeImageProcessor(vae_scale_factor=VAE_SCALE_FACTOR * 2)

    # CFG disabled: only the positive embeddings are used; negatives discarded.
    prompt_embeds, _ = encode_prompt(
        tokenizer,
        text_encoder,
        prompt_text,
        device,
        do_classifier_free_guidance=False,
        negative_prompt=args.negative_prompt,
        max_sequence_length=args.max_seq_len,
    )

    latents = prepare_latents(
        batch_size=1,
        num_channels_latents=NUM_CHANNELS_LATENTS,
        height=args.height,
        width=args.width,
        dtype=torch.float32,
        device=device,
        generator=torch.Generator(device=device).manual_seed(args.seed),
    )

    # Token count of the patchified latent image (2x2 patches) drives mu.
    image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
    mu = calculate_shift(
        image_seq_len,
        scheduler.config.get("base_image_seq_len", 256),
        scheduler.config.get("max_image_seq_len", 4096),
        scheduler.config.get("base_shift", 0.5),
        scheduler.config.get("max_shift", 1.15),
    )
    timesteps = retrieve_timesteps(scheduler, args.steps, device=device, mu=mu)

    onnx_path = Path(args.transformer_onnx) if args.transformer_onnx else None
    transformer_runner = AxSplitTransformer(
        Path(args.transformer_config),
        onnx_path,
        Path(args.transformer_subgraph_dir),
    )
    # Prefer a tensor literally named "sample" when present; otherwise fall
    # back to the runner's inferred default. --final-output-name overrides.
    available_outputs = [name for spec in transformer_runner.specs for name in getattr(spec, "end", [])]
    preferred_output = "sample" if "sample" in available_outputs else transformer_runner.default_final_output
    final_output_name = args.final_output_name or preferred_output
    if final_output_name not in available_outputs:
        raise ValueError(f"指定的输出 {final_output_name} 不存在,可选: {available_outputs}")

    prompt_embeds_tensor = _stack_prompt_embeddings(prompt_embeds)

    iterator = timesteps if args.no_progress else tqdm(timesteps, desc="AX Denoising", dynamic_ncols=True)
    for t in iterator:
        timestep = t.expand(latents.shape[0])
        # Map scheduler timestep t (0..1000 range) to (1000 - t) / 1000 —
        # presumably the normalized-time convention the AXModel was exported
        # with; TODO confirm against the export script.
        timestep_model_input = (1000 - timestep) / 1000

        # AXModel inputs are fp32 numpy; latents gain a singleton frame axis
        # (dim 2), matching a single-frame video layout.
        latent_model_input = latents.to(torch.float32)
        latent_np = latent_model_input.unsqueeze(2).to(dtype=torch.float32).cpu().numpy()
        prompt_np = prompt_embeds_tensor.to(dtype=torch.float32).cpu().numpy()
        timestep_np = timestep_model_input.to(dtype=torch.float32).cpu().numpy()

        model_out = transformer_runner(latent_np, prompt_np, timestep_np, final_output_name)
        # Drop the singleton frame axis again if the model kept it (checked
        # both on the numpy array and, redundantly, on the tensor below).
        if model_out.ndim == 5 and model_out.shape[2] == 1:
            model_out = np.squeeze(model_out, axis=2)
        model_out_tensor = torch.from_numpy(model_out).to(device=device, dtype=torch.float32)
        if model_out_tensor.dim() == 5 and model_out_tensor.size(2) == 1:
            model_out_tensor = model_out_tensor.squeeze(2)
        # Sign flip before the scheduler step — presumably the exported model
        # predicts the negated flow/velocity; TODO confirm.
        noise_pred = -model_out_tensor

        latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

    # Release sub-graph sessions before VAE decode to lower peak memory.
    transformer_runner.close()

    # Undo the VAE latent normalization before decoding.
    latents = (latents / VAE_SCALING_FACTOR) + VAE_SHIFT_FACTOR
    decoder_input = latents.to(dtype=torch.float32).cpu().numpy()

    if args.save_decoder_input:
        save_dir_path = Path(args.save_dir)
        save_dir_path.mkdir(parents=True, exist_ok=True)
        np.save(save_dir_path / "decoder_input.npy", decoder_input)
        logger.info("已保存 decoder 输入为 npy")

    del transformer_runner

    vae_decoder_session = AxInferenceSession(Path(args.vae_axmodel).as_posix())

    if decoder_input.ndim == 5 and decoder_input.shape[2] == 1:
        decoder_input = np.squeeze(decoder_input, axis=2)
    image = vae_decoder_session.run(None, {"latents": decoder_input})[0]

    del vae_decoder_session
    image = torch.from_numpy(image).to(device=device, dtype=torch.float32)
    image = image_processor.postprocess(image, output_type="pil")

    # Filename uses the module-level random prompt index even when --prompt
    # overrode the prompt text.
    save_dir = Path(args.save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    target_path = save_dir / f"z_image_axmodel_{prompt_idx}.png"
    image[0].save(target_path)
    logger.info(f"AXModel 推理完成,结果保存到 {target_path}")
| |
|
| |
|
if __name__ == "__main__":
    # Usage examples (no-op string, kept as inline documentation).
    """
    # Example command for 512x512 generation:
    python3 examples/z_image_fun/launcher_axmodel.py \
    --transformer-config pulsar2_configs/transformers_subgraph.json \
    --transformer-subgraph-dir comliled_subgraph_from_all_onnx \
    --vae-axmodel vae_decoder.axmodel

    # Example command for 1728x992 generation:
    python3 examples/z_image_fun/launcher_axmodel.py \
    --transformer-config pulsar2_configs/transformers_subgraph_1728x992.json \
    --transformer-subgraph-dir transformers_body_only_1728_992_split_onnx \
    --vae-axmodel onnx-models-1728x992/vae_decoder_simp_slim.axmodel \
    --max-seq-len 256 \
    --height 1728 --width 992
    """

    main()