| import os |
| import sys |
| import numpy as np |
| import torch |
| import onnxruntime as ort |
| from contextlib import contextmanager |
| from typing import List, Optional, Union |
|
|
| from omegaconf import OmegaConf |
| from PIL import Image |
| from loguru import logger |
|
|
# Make intra-repo imports work regardless of the CWD: this file's directory,
# its parent, and grandparent are all prepended to sys.path.
current_file_path = os.path.abspath(__file__)
project_roots = [
    os.path.dirname(current_file_path),
    os.path.dirname(os.path.dirname(current_file_path)),
    os.path.dirname(os.path.dirname(os.path.dirname(current_file_path))),
]
for project_root in project_roots:
    # Plain statement instead of a ternary expression used only for its
    # side effect (idiom fix; behavior unchanged).
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
repo_root = project_roots[-1]  # grandparent directory is treated as the repo root
onnx_models_dir = os.path.join(repo_root, "onnx-models")

# Locations of the pre-exported ONNX graphs and the precision they were
# exported with.
vae_encoder_onnx_path = os.path.join(onnx_models_dir, "vae_encoder_simp_slim.onnx")
vae_decoder_onnx_path = os.path.join(onnx_models_dir, "vae_decoder_simp_slim.onnx")
use_transformer_onnx = True
transformer_body_onnx_path = os.path.join(onnx_models_dir, "z_image_transformer_body_only_simp_slim.onnx")
onnx_export_dtype = torch.float16  # dtype the ONNX transformer graph expects
_printed_onnx_debug = False  # one-shot flag for first-iteration debug prints
|
|
|
|
def _select_onnx_providers():
    """Choose ONNX Runtime execution providers, preferring CUDA.

    CPU is always appended as the final fallback.
    NOTE(review): AzureExecutionProvider is kept as the second preference to
    match the original ordering — confirm this is intentional, as it targets
    remote endpoints rather than local acceleration.
    """
    available = ort.get_available_providers()
    for preferred in ("CUDAExecutionProvider", "AzureExecutionProvider"):
        if preferred in available:
            return [preferred, "CPUExecutionProvider"]
    return ["CPUExecutionProvider"]
|
|
# Provider list computed once at import time and reused for every session.
onnx_providers = _select_onnx_providers()
|
|
| from diffusers import FlowMatchEulerDiscreteScheduler |
| from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler |
| from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler |
|
|
| from videox_fun.models import (AutoencoderKL, AutoTokenizer, |
| Qwen3ForCausalLM, ZImageTransformer2DModel) |
| from diffusers.utils.torch_utils import randn_tensor |
| from videox_fun.utils.utils import get_image_latent |
|
|
# Inference-only script: disable autograd globally.
torch.set_grad_enabled(False)
# Cache of ort.InferenceSession objects keyed by (model_path, providers tuple).
_ort_sessions = {}
| def _ort_type_to_torch(type_str: str): |
| if type_str is None: |
| return None |
| if "float16" in type_str: |
| return torch.float16 |
| if "bfloat16" in type_str: |
| return torch.bfloat16 |
| if "float" in type_str: |
| return torch.float32 |
| return None |
|
|
|
|
def ort_inference(onnx_model_path: str, inputs: dict, providers=None):
    """Run an ONNX model and return its raw outputs.

    Sessions are cached per (path, providers). Inputs (numpy arrays or torch
    tensors) are converted to the float dtype each session input declares;
    bfloat16 tensors are upcast to float32 because numpy has no bfloat16.

    Raises FileNotFoundError when the model file is missing and TypeError for
    unsupported input values.
    """
    providers = providers or ["CPUExecutionProvider"]
    key = (onnx_model_path, tuple(providers))
    session = _ort_sessions.get(key)
    if session is None:
        if not os.path.exists(onnx_model_path):
            raise FileNotFoundError(f"ONNX model not found: {onnx_model_path}")
        session = ort.InferenceSession(onnx_model_path, providers=providers)
        _ort_sessions[key] = session

    # Desired torch dtype per declared session input (None => leave as-is).
    expected_dtypes = {i.name: _ort_type_to_torch(getattr(i, "type", None)) for i in session.get_inputs()}

    feed = {}
    for k, v in inputs.items():
        want = expected_dtypes.get(k)
        if isinstance(v, np.ndarray):
            if want is not None:
                np_dtype = np.float16 if want == torch.float16 else np.float32
                if v.dtype != np_dtype:
                    v = v.astype(np_dtype)
            feed[k] = v
        elif torch.is_tensor(v):
            if want is not None and v.dtype != want:
                v = v.to(dtype=want)
            elif v.dtype == torch.bfloat16:
                # numpy cannot represent bfloat16; upcast before .numpy().
                v = v.to(dtype=torch.float32)
            feed[k] = v.detach().to("cpu").numpy()
        else:
            raise TypeError(f"Unsupported input type for key {k}: {type(v)}")
    return session.run(None, feed)
|
|
|
|
| def _infer_module_device(module: torch.nn.Module) -> torch.device: |
| param = next(module.parameters(), None) |
| if param is not None: |
| return param.device |
| buffer = next(module.buffers(), None) |
| if buffer is not None: |
| return buffer.device |
| return torch.device("cpu") |
|
|
|
|
@contextmanager
def module_to_device(module: torch.nn.Module, target_device: torch.device):
    """Temporarily move `module` to `target_device`, restoring the original
    device on exit. A None module or None target is a no-op; after moving
    back off a CUDA device, that device's allocator cache is emptied.
    """
    if module is None:
        yield module
        return
    # Infer the module's current device (first parameter, else first buffer,
    # else CPU) — inlined from the standalone helper.
    sample = next(module.parameters(), None)
    if sample is None:
        sample = next(module.buffers(), None)
    source = sample.device if sample is not None else torch.device("cpu")
    destination = target_device or source
    must_move = source != destination
    used_cuda = must_move and destination.type == "cuda"
    if must_move:
        module.to(destination)
    try:
        yield module
    finally:
        if must_move:
            module.to(source)
            if used_cuda and torch.cuda.is_available():
                index = destination.index
                if index is None:
                    index = torch.cuda.current_device()
                # Free cached blocks on the GPU we just vacated.
                with torch.cuda.device(index):
                    torch.cuda.empty_cache()
|
|
|
|
# --- Pipeline configuration -------------------------------------------------
config_path_default = "config/z_image/z_image.yaml"  # resolved against repo_root below
model_name = "models/Diffusion_Transformer/Z-Image-Turbo/"

# Torch weights run in bfloat16; ONNX graphs use onnx_export_dtype instead.
weight_dtype = torch.bfloat16

# Strength of the control-context branch (not exercised by this script).
control_context_scale = 0.75

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Text tower: Qwen3 tokenizer + causal LM used as the prompt encoder.
tokenizer = AutoTokenizer.from_pretrained(
    model_name, subfolder="tokenizer"
)
text_encoder = Qwen3ForCausalLM.from_pretrained(
    model_name, subfolder="text_encoder", torch_dtype=weight_dtype,
    low_cpu_mem_usage=True,
)
text_encoder.eval()
|
|
|
|
def _encode_prompt(
    prompt: Union[str, List[str]],
    device: Optional[torch.device] = None,
    prompt_embeds: Optional[List[torch.FloatTensor]] = None,
    max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
    """Encode prompts with the module-level Qwen3 text encoder.

    Each prompt is wrapped in the chat template (generation prompt and
    thinking enabled), tokenized with padding/truncation to
    `max_sequence_length`, and run through the encoder; the second-to-last
    hidden state is returned as one tensor per prompt.

    If `prompt_embeds` is already provided it is returned unchanged.

    Fix: the original rewrote `prompt[i]` in place, mutating the caller's
    list as a side effect; a fresh templated list is built instead.
    """
    device = device or _infer_module_device(text_encoder)

    if prompt_embeds is not None:
        return prompt_embeds

    if isinstance(prompt, str):
        prompt = [prompt]

    # Apply the chat template WITHOUT mutating the caller's list.
    templated = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": prompt_item}],
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True,
        )
        for prompt_item in prompt
    ]

    text_inputs = tokenizer(
        templated,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )

    text_input_ids = text_inputs.input_ids.to(device)
    prompt_masks = text_inputs.attention_mask.to(device).bool()

    # Temporarily host the encoder on `device`, then restore it.
    with module_to_device(text_encoder, device):
        hidden = text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_masks,
            output_hidden_states=True,
        ).hidden_states[-2]  # second-to-last layer's hidden states

    # One tensor per prompt so callers can handle samples independently.
    return [hidden[i] for i in range(len(hidden))]
|
|
|
|
def encode_prompt(
    prompt: Union[str, List[str]],
    device: Optional[torch.device] = None,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[List[torch.FloatTensor]] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    max_sequence_length: int = 512,
):
    """Encode positive (and, under CFG, negative) prompts.

    Returns (prompt_embeds, negative_prompt_embeds); the negative list is
    empty when classifier-free guidance is disabled.
    """
    if isinstance(prompt, str):
        prompt = [prompt]
    prompt_embeds = _encode_prompt(
        prompt=prompt,
        device=device,
        prompt_embeds=prompt_embeds,
        max_sequence_length=max_sequence_length,
    )

    if not do_classifier_free_guidance:
        return prompt_embeds, []

    # Default negatives: one empty string per positive prompt.
    if negative_prompt is None:
        negative_prompt = ["" for _ in prompt]
    else:
        if isinstance(negative_prompt, str):
            negative_prompt = [negative_prompt]
        assert len(prompt) == len(negative_prompt)
    negative_prompt_embeds = _encode_prompt(
        prompt=negative_prompt,
        device=device,
        prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
    )
    return prompt_embeds, negative_prompt_embeds
|
|
|
|
| def _stack_prompt_embeddings(prompt_embeds_input): |
| if isinstance(prompt_embeds_input, list): |
| return torch.stack(prompt_embeds_input, dim=0) |
| return prompt_embeds_input |
|
|
|
|
def _onnx_shape_compatible(model_path: str, providers, latent_shape, prompt_shape, verbose=False):
    """Check whether an exported ONNX transformer accepts the runtime shapes.

    Only statically-exported (int) dims are compared; symbolic dims always
    pass. Any inspection failure returns True so the caller proceeds and
    lets actual inference surface the problem.

    Fix: the width-mismatch message referenced an undefined name `latent_w`,
    raising NameError exactly when a width mismatch had to be reported; it now
    formats expected/actual widths like the other messages.
    """
    try:
        sess = ort.InferenceSession(model_path, providers=providers)
        inputs = {i.name: i for i in sess.get_inputs()}
        ok = True
        mismatch_msgs = []
        if "latent_model_input" in inputs:
            shape = inputs["latent_model_input"].shape
            exp_h, exp_w = shape[3], shape[4]
            if isinstance(exp_h, int) and exp_h != latent_shape[3]:
                ok = False
                mismatch_msgs.append(f"latent_h expected {exp_h}, got {latent_shape[3]}")
            if isinstance(exp_w, int) and exp_w != latent_shape[4]:
                ok = False
                mismatch_msgs.append(f"latent_w expected {exp_w}, got {latent_shape[4]}")
        if "prompt_embeds" in inputs:
            pshape = inputs["prompt_embeds"].shape
            exp_seq = pshape[1]
            if isinstance(exp_seq, int) and exp_seq != prompt_shape[1]:
                ok = False
                mismatch_msgs.append(f"seq_len expected {exp_seq}, got {prompt_shape[1]}")
        if verbose and (not ok or mismatch_msgs):
            print(f"[DEBUG] ONNX shape check for {model_path}")
            print(f"  providers={providers}")
            print(f"  model latent shape={inputs.get('latent_model_input').shape if 'latent_model_input' in inputs else 'n/a'}")
            print(f"  model prompt shape={inputs.get('prompt_embeds').shape if 'prompt_embeds' in inputs else 'n/a'}")
            print(f"  runtime latent shape={latent_shape}, prompt shape={prompt_shape}")
            if mismatch_msgs:
                print(f"  mismatch: {', '.join(mismatch_msgs)}")
        return ok
    except Exception as exc:
        # Best-effort check only: report and allow the caller to continue.
        print(f"ONNX shape check failed for {model_path}: {exc}")
        return True
|
|
|
|
def prepare_latents(
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Draw seeded Gaussian initial latents, or validate user-supplied ones.

    Pixel height/width are converted to latent rows/cols via the module-level
    `vae_scale_factor`, rounded down to an even count. Raises ValueError when
    provided `latents` have the wrong shape.
    """
    latent_height = 2 * (int(height) // (vae_scale_factor * 2))
    latent_width = 2 * (int(width) // (vae_scale_factor * 2))
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    if latents is None:
        return randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    if latents.shape != shape:
        raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
    return latents.to(device)
|
|
|
|
# --- Runtime / distribution flags (single-process defaults) -----------------
GPU_memory_mode = "model_full_load"
ulysses_degree = 1
ring_degree = 1
fsdp_dit = False
fsdp_text_encoder = False
compile_dit = False

# Sampler choice and optional override checkpoints (None = use model_name).
sampler_name = "Flow"
transformer_path = None
vae_path = None
lora_path = None

# Generation settings.
sample_size = [512, 512]  # (height, width) in pixels
prompt = "(masterpiece, best quality), 1 girl on the beach"
negative_prompt = " "
guidance_scale = 0.0  # classifier-free guidance strength
seed = 42
num_inference_steps = 9
lora_weight = 0.55
save_path = "samples/z-image-t2i-nocontrol"  # output directory for PNGs
|
|
| def _resolve_config_path(path: str) -> Optional[str]: |
| candidate = path if os.path.isabs(path) else os.path.join(repo_root, path) |
| return candidate if os.path.exists(candidate) else None |
|
|
|
|
# Load the OmegaConf YAML config when present; otherwise run with defaults.
config_path = _resolve_config_path(config_path_default)
config = OmegaConf.load(config_path) if config_path else None

# Optional extra kwargs forwarded to the transformer constructor.
extra_kwargs = {}
if config is not None and hasattr(config, "transformer_additional_kwargs"):
    extra_kwargs = OmegaConf.to_container(config.transformer_additional_kwargs, resolve=True)
|
|
# Diffusion transformer backbone, cast to weight_dtype and placed on device.
transformer = ZImageTransformer2DModel.from_pretrained(
    model_name,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
    **({"transformer_additional_kwargs": extra_kwargs} if extra_kwargs else {}),
).to(weight_dtype).to(device)

# Optionally overwrite the weights from a standalone checkpoint file.
if transformer_path is not None:
    print(f"From checkpoint: {transformer_path}")
    if transformer_path.endswith("safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(transformer_path)
    else:
        state_dict = torch.load(transformer_path, map_location="cpu")
    # Some checkpoints nest the weights under a "state_dict" key.
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
    m, u = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
|
|
# Map the sampler name to its scheduler class.
Chosen_Scheduler = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}[sampler_name]

scheduler = Chosen_Scheduler.from_pretrained(
    model_name,
    subfolder="scheduler"
)

height, width = sample_size
vae_scale_factor = 8  # spatial downscale factor of the VAE
# Extra factor of 2 — presumably the DiT's patch_size=2 (see transformer
# call below); TODO confirm.
vae_scale = vae_scale_factor * 2

if height % vae_scale != 0 or width % vae_scale != 0:
    raise ValueError(f"Height/Width must be divisible by {vae_scale}")

# CFG bookkeeping kept for parity with the full pipeline; unused below.
_guidance_scale = guidance_scale
_joint_attention_kwargs = None
_interrupt = False
_cfg_normalization = False
_cfg_truncation = 1.0
|
|
# Derive batch size from the prompt input(s).
prompt_embeds = None
if isinstance(prompt, str):
    batch_size = 1
elif isinstance(prompt, list):
    batch_size = len(prompt)
else:
    batch_size = len(prompt_embeds)

# NOTE(review): the derivation above is immediately overridden, making it
# dead code in this script; kept as-is for pipeline parity.
batch_size = 1

weight_dtype = text_encoder.dtype  # re-read the dtype the encoder actually uses
num_channels_latents = 16  # latent channel count used for prepare_latents below
inpaint_latent = None

from diffusers.image_processor import VaeImageProcessor
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)
|
|
# Torch VAE — used for decoding when the ONNX decoder file is absent.
vae = AutoencoderKL.from_pretrained(
    model_name,
    subfolder="vae"
).to(weight_dtype)
vae.eval()
vae_config_shift_factor = getattr(vae.config, "shift_factor", 0.0)
vae_config_scaling_factor = getattr(vae.config, "scaling_factor", 1.0)

# Control scale tensor (unused in this no-control variant).
control_context_scale_tensor = torch.tensor([control_context_scale], device=device, dtype=torch.float32)

# Enable the ONNX transformer only when requested AND the file exists.
onnx_transformer_enabled = use_transformer_onnx and os.path.exists(transformer_body_onnx_path)
if use_transformer_onnx and not onnx_transformer_enabled:
    print(f"ONNX transformer requested but missing file, fallback to torch. body={transformer_body_onnx_path}")
else:
    # NOTE(review): this branch also runs when use_transformer_onnx is False.
    print(f"[DEBUG] ONNX providers: {onnx_providers}")
    print(f"[DEBUG] transformer body onnx path: {transformer_body_onnx_path}")

# Sampling settings: no classifier-free guidance, shortened prompt context.
do_classifier_free_guidance = False
negative_prompt_embeds = None
max_sequence_length = 128
|
|
# Encode the prompt (the negative branch is skipped since CFG is disabled).
prompt_embeds, negative_prompt_embeds = encode_prompt(
    prompt=prompt,
    negative_prompt=negative_prompt,
    do_classifier_free_guidance=do_classifier_free_guidance,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    device=device,
    max_sequence_length=max_sequence_length,
)

num_images_per_prompt = 1
latents = None

# Seeded Gaussian initial latents, kept in float32 for the scheduler.
latents = prepare_latents(
    batch_size * num_images_per_prompt,
    num_channels_latents,
    height,
    width,
    torch.float32,
    device,
    torch.Generator(device=device).manual_seed(seed),
    latents,
)

# Repeat embeddings when generating multiple images per prompt.
if num_images_per_prompt > 1:
    prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
    if do_classifier_free_guidance and negative_prompt_embeds:
        negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]

actual_batch_size = batch_size * num_images_per_prompt
# Token count after 2x2 patchification of the latent grid.
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
|
|
|
|
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the schedule shift `mu` from the image sequence
    length: base_shift at base_seq_len, max_shift at max_seq_len."""
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return slope * image_seq_len + intercept
|
|
|
|
# Dynamic shift for the flow-matching schedule, driven by the token count.
mu = calculate_shift(
    image_seq_len,
    scheduler.config.get("base_image_seq_len", 256),
    scheduler.config.get("max_image_seq_len", 4096),
    scheduler.config.get("base_shift", 0.5),
    scheduler.config.get("max_shift", 1.15),
)
# NOTE(review): presumably forces the schedule to reach sigma = 0 — confirm
# against the scheduler implementation.
scheduler.sigma_min = 0.0
scheduler_kwargs = {"mu": mu}
|
|
|
|
| import inspect |
|
|
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure `scheduler` from custom timesteps, custom sigmas, or a plain
    step count, and return (timesteps, num_inference_steps).

    Raises ValueError when both `timesteps` and `sigmas` are supplied, or
    when the scheduler's `set_timesteps` does not accept the requested
    custom schedule keyword.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _accepts(keyword):
        # Introspect whether scheduler.set_timesteps takes this keyword.
        return keyword in set(inspect.signature(scheduler.set_timesteps).parameters.keys())

    if timesteps is not None:
        if not _accepts("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        resolved = scheduler.timesteps
        return resolved, len(resolved)

    if sigmas is not None:
        if not _accepts("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        resolved = scheduler.timesteps
        return resolved, len(resolved)

    scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
    return scheduler.timesteps, num_inference_steps
|
|
|
|
# Resolve the timestep schedule; `mu` (in scheduler_kwargs) sets the shift.
timesteps, num_inference_steps = retrieve_timesteps(
    scheduler,
    num_inference_steps=num_inference_steps,
    device=device,
    **scheduler_kwargs,
)

# NOTE(review): computed but not used anywhere below in this script.
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
|
|
# --- Denoising loop ---------------------------------------------------------
for i, t in enumerate(timesteps):
    timestep = t.expand(latents.shape[0])
    # NOTE(review): time appears reversed and normalized to [0, 1] for the
    # model input — confirm against the Z-Image transformer's convention.
    timestep_model_input = (1000 - timestep) / 1000

    latent_model_input = latents.to(transformer.dtype)
    prompt_embeds_tensor = _stack_prompt_embeddings(prompt_embeds)

    # First-iteration debug dump of runtime shapes.
    if not _printed_onnx_debug:
        print(f"[DEBUG] runtime latent shape: {latent_model_input.unsqueeze(2).shape}, dtype={latent_model_input.dtype}")
        print(f"[DEBUG] runtime prompt shape: {prompt_embeds_tensor.shape}, dtype={prompt_embeds_tensor.dtype}")

    if onnx_transformer_enabled:
        # ONNX path: add a frame axis (unsqueeze(2)) and cast to export dtype.
        onnx_latent = latent_model_input.unsqueeze(2).to(dtype=onnx_export_dtype)
        onnx_prompt = prompt_embeds_tensor.to(dtype=onnx_export_dtype)

        if not _onnx_shape_compatible(transformer_body_onnx_path, onnx_providers, onnx_latent.shape, onnx_prompt.shape, verbose=True):
            print("ONNX transformer 输入尺寸与当前推理不匹配,回退到 torch。")
            onnx_transformer_enabled = False  # permanent fallback for later steps too
        else:
            body_inputs = {
                "latent_model_input": onnx_latent,
                "timestep": timestep_model_input.to(dtype=torch.float32),
                "prompt_embeds": onnx_prompt,
            }
            model_out = ort_inference(transformer_body_onnx_path, body_inputs, providers=onnx_providers)[0]
            model_out_tensor = torch.from_numpy(model_out).to(device=device, dtype=onnx_export_dtype)
            # Per-sample list to mirror the torch transformer's output form.
            model_out_list = list(model_out_tensor)
    # Torch fallback (also taken when the shape check above just disabled ONNX).
    if not onnx_transformer_enabled:
        latent_model_input_list = list(latent_model_input.unsqueeze(2).unbind(dim=0))
        timestep_model_input = timestep_model_input.to(dtype=transformer.dtype)

        if device.type == "cuda":
            with torch.autocast(device_type="cuda", dtype=transformer.dtype):
                model_out_list = transformer(
                    latent_model_input_list,
                    timestep_model_input,
                    prompt_embeds,
                    patch_size=2,
                    f_patch_size=1,
                )[0]
        else:
            model_out_list = transformer(
                latent_model_input_list,
                timestep_model_input,
                prompt_embeds,
                patch_size=2,
                f_patch_size=1,
            )[0]

    # Re-batch, drop the frame axis, and flip the sign of the prediction.
    # NOTE(review): the negation pairs with the reversed-time input above —
    # confirm against the reference pipeline before changing either.
    noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
    noise_pred = noise_pred.squeeze(2)
    noise_pred = -noise_pred

    # Scheduler update in float32.
    latents = scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
    assert latents.dtype == torch.float32

    # Debug prints only on the first iteration.
    if not _printed_onnx_debug:
        _printed_onnx_debug = True
|
|
| |
# --- Decode latents to an image ---------------------------------------------
output_type = "pil"
if output_type == "latent":
    image = latents
else:
    latents = latents.to(vae.dtype)
    # Undo the VAE latent normalization before decoding.
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

    if os.path.exists(vae_decoder_onnx_path):
        # ONNX decode path (ort_inference converts dtypes as needed).
        image = ort_inference(
            vae_decoder_onnx_path,
            {"latents": latents},
        )[0]
        image = torch.from_numpy(image).to(device=device, dtype=vae.dtype)
    else:
        # Torch decode path; VAE is moved to `device` only for this call.
        with module_to_device(vae, device):
            image = vae.decode(latents, return_dict=False)[0]
    image = image_processor.postprocess(image, output_type=output_type)

sample = image
|
|
def save_results():
    """Save the first generated image into `save_path` under an incrementing,
    zero-padded filename (e.g. 00000001.png).

    Fixes: removes the redundant existence check (`os.makedirs(...,
    exist_ok=True)` already handles an existing directory) and renames the
    misleading `video_path` — the output is a PNG image.
    """
    os.makedirs(save_path, exist_ok=True)
    # Index = current number of directory entries + 1; this can collide if
    # files are ever deleted, matching the original behavior.
    index = len(os.listdir(save_path)) + 1
    image_path = os.path.join(save_path, str(index).zfill(8) + ".png")
    sample[0].save(image_path)
|
|
# Persist the result and log the destination directory.
save_results()
logger.info(f"Saved image to {save_path}")
|
|