# Z-Image-Turbo / VideoX-Fun / examples/z_image_fun/launcher_no_controlnet.py
# Author: yongqiang — "initialize this repo" (commit ba96580)
import os
import sys
import numpy as np
import torch
import onnxruntime as ort
from contextlib import contextmanager
from typing import List, Optional, Union
from omegaconf import OmegaConf
from PIL import Image
from loguru import logger
# Resolve the repo root by walking up from this file and make the current
# directory, its parent, and grandparent importable (grandparent == repo root).
current_file_path = os.path.abspath(__file__)
project_roots = [os.path.dirname(current_file_path), os.path.dirname(os.path.dirname(current_file_path)), os.path.dirname(os.path.dirname(os.path.dirname(current_file_path)))]
for project_root in project_roots:
    sys.path.insert(0, project_root) if project_root not in sys.path else None
repo_root = project_roots[-1]
# Location of the pre-exported ONNX models (simplified + slimmed graphs).
onnx_models_dir = os.path.join(repo_root, "onnx-models")
# vae_encoder_onnx_path = ""
# vae_decoder_onnx_path = ""
vae_encoder_onnx_path = os.path.join(onnx_models_dir, "vae_encoder_simp_slim.onnx")
vae_decoder_onnx_path = os.path.join(onnx_models_dir, "vae_decoder_simp_slim.onnx")
# Toggle for the ONNX transformer body; the script falls back to the torch
# model when the file is missing or its static shapes do not match runtime.
use_transformer_onnx = True
transformer_body_onnx_path = os.path.join(onnx_models_dir, "z_image_transformer_body_only_simp_slim.onnx")
# dtype the ONNX graphs were exported with — inputs are cast to this.
onnx_export_dtype = torch.float16
# One-shot flag so per-step debug shapes print only on the first loop step.
_printed_onnx_debug = False
def _select_onnx_providers():
    """Pick ONNX Runtime providers by preference: CUDA, then Azure, then CPU."""
    available = set(ort.get_available_providers())
    for preferred in ("CUDAExecutionProvider", "AzureExecutionProvider"):
        if preferred in available:
            return [preferred, "CPUExecutionProvider"]
    return ["CPUExecutionProvider"]
# Providers are selected once at import time and reused for every session.
onnx_providers = _select_onnx_providers()
from diffusers import FlowMatchEulerDiscreteScheduler
from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler
from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from videox_fun.models import (AutoencoderKL, AutoTokenizer,
                               Qwen3ForCausalLM, ZImageTransformer2DModel)
from diffusers.utils.torch_utils import randn_tensor
from videox_fun.utils.utils import get_image_latent
# Inference-only script: disable autograd globally.
torch.set_grad_enabled(False)
# Cache of ort.InferenceSession objects keyed by (model_path, providers tuple).
_ort_sessions = {}
def _ort_type_to_torch(type_str: str):
if type_str is None:
return None
if "float16" in type_str:
return torch.float16
if "bfloat16" in type_str:
return torch.bfloat16
if "float" in type_str:
return torch.float32
return None
def ort_inference(onnx_model_path: str, inputs: dict, providers=None):
    """Run an ONNX model through a cached InferenceSession, aligning input dtypes.

    Sessions are cached in the module-level ``_ort_sessions`` dict keyed by
    (model_path, providers) so repeated calls reuse the same session.

    Args:
        onnx_model_path: path to the .onnx file.
        inputs: mapping of input name -> np.ndarray or torch.Tensor.
        providers: ONNX Runtime provider list; defaults to CPU only.

    Returns:
        List of np.ndarray outputs from ``session.run``.

    Raises:
        FileNotFoundError: if the model file does not exist.
        TypeError: if an input value is neither ndarray nor tensor.
    """
    providers = providers or ["CPUExecutionProvider"]
    cache_key = (onnx_model_path, tuple(providers))
    if cache_key not in _ort_sessions:
        if not os.path.exists(onnx_model_path):
            raise FileNotFoundError(f"ONNX model not found: {onnx_model_path}")
        _ort_sessions[cache_key] = ort.InferenceSession(onnx_model_path, providers=providers)
    session = _ort_sessions[cache_key]
    # Align dtypes with the ONNX input declarations to avoid float/bfloat16
    # type conflicts at session.run time.
    input_type_map = {i.name: _ort_type_to_torch(getattr(i, "type", None)) for i in session.get_inputs()}
    inputs_onnx = {}
    for name, value in inputs.items():
        target_dtype = input_type_map.get(name)
        if isinstance(value, np.ndarray):
            if target_dtype is not None:
                # numpy has no bfloat16; use float32 for anything but float16.
                np_dtype = np.float16 if target_dtype == torch.float16 else np.float32
                if value.dtype != np_dtype:
                    value = value.astype(np_dtype)
            inputs_onnx[name] = value
        elif torch.is_tensor(value):
            if target_dtype is not None and value.dtype != target_dtype:
                # Never cast to bfloat16 here: .numpy() below cannot represent
                # it (the original code crashed on bfloat16 targets). Feed
                # float32 instead.
                value = value.to(dtype=torch.float32 if target_dtype == torch.bfloat16 else target_dtype)
            if value.dtype == torch.bfloat16:
                # numpy does not support bfloat16 — convert up front.
                value = value.to(dtype=torch.float32)
            inputs_onnx[name] = value.detach().to("cpu").numpy()
        else:
            raise TypeError(f"Unsupported input type for key {name}: {type(value)}")
    return session.run(None, inputs_onnx)
def _infer_module_device(module: torch.nn.Module) -> torch.device:
param = next(module.parameters(), None)
if param is not None:
return param.device
buffer = next(module.buffers(), None)
if buffer is not None:
return buffer.device
return torch.device("cpu")
@contextmanager
def module_to_device(module: torch.nn.Module, target_device: torch.device):
    """Temporarily move `module` to `target_device`, restoring it on exit.

    After moving a module back off a CUDA device, the CUDA allocator cache
    for that device is released to keep peak GPU memory low. A None module
    is yielded unchanged.
    """
    if module is None:
        yield module
        return
    # Probe the module's current device from its first parameter or buffer.
    probe = next(module.parameters(), None)
    if probe is None:
        probe = next(module.buffers(), None)
    original_device = probe.device if probe is not None else torch.device("cpu")
    target_device = target_device or original_device
    needs_move = original_device != target_device
    moved_to_cuda = needs_move and target_device.type == "cuda"
    if needs_move:
        module.to(target_device)
    try:
        yield module
    finally:
        if needs_move:
            module.to(original_device)
            if moved_to_cuda and torch.cuda.is_available():
                device_index = target_device.index
                if device_index is None:
                    device_index = torch.cuda.current_device()
                with torch.cuda.device(device_index):
                    torch.cuda.empty_cache()
# Config and model path (body-only transformer, no ControlNet branch)
config_path_default = "config/z_image/z_image.yaml"
model_name = "models/Diffusion_Transformer/Z-Image-Turbo/"
# Use torch.float16 if GPU does not support torch.bfloat16
weight_dtype = torch.bfloat16
control_context_scale = 0.75
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
# Get tokenizer and text_encoder (Qwen3 causal LM used as text encoder).
tokenizer = AutoTokenizer.from_pretrained(
    model_name, subfolder="tokenizer"
)
text_encoder = Qwen3ForCausalLM.from_pretrained(
    model_name, subfolder="text_encoder", torch_dtype=weight_dtype,
    low_cpu_mem_usage=True,
)
text_encoder.eval()
def _encode_prompt(
    prompt: Union[str, List[str]],
    device: Optional[torch.device] = None,
    prompt_embeds: Optional[List[torch.FloatTensor]] = None,
    max_sequence_length: int = 512,
) -> List[torch.FloatTensor]:
    """Encode prompt(s) with the Qwen3 text encoder.

    Each prompt is wrapped in the chat template before tokenization; the
    second-to-last hidden state is used as the embedding.

    Args:
        prompt: a single prompt string or a list of prompts.
        device: target device; defaults to the text encoder's device.
        prompt_embeds: precomputed embeddings — returned as-is when given.
        max_sequence_length: padded/truncated token length.

    Returns:
        List of per-prompt embedding tensors (one entry per prompt).
    """
    device = device or _infer_module_device(text_encoder)
    if prompt_embeds is not None:
        return prompt_embeds
    if isinstance(prompt, str):
        prompt = [prompt]
    # Build a new list instead of mutating the caller's `prompt` list in
    # place (the original rewrote prompt[i], leaking templated strings back
    # to the caller).
    templated_prompts = []
    for prompt_item in prompt:
        messages = [
            {"role": "user", "content": prompt_item},
        ]
        templated_prompts.append(
            tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True,
                enable_thinking=True,
            )
        )
    text_inputs = tokenizer(
        templated_prompts,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids.to(device)
    prompt_masks = text_inputs.attention_mask.to(device).bool()
    with module_to_device(text_encoder, device):
        prompt_embeds = text_encoder(
            input_ids=text_input_ids,
            attention_mask=prompt_masks,
            output_hidden_states=True,
        ).hidden_states[-2]
    # NOTE(review): padding positions are intentionally kept (the masked
    # variant is commented out upstream); each item keeps full length.
    return [prompt_embeds[i] for i in range(len(prompt_embeds))]
def encode_prompt(
    prompt: Union[str, List[str]],
    device: Optional[torch.device] = None,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[List[torch.FloatTensor]] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    max_sequence_length: int = 512,
):
    """Encode the positive prompt and, under CFG, the negative prompt.

    Returns:
        (prompt_embeds, negative_prompt_embeds); the negative list is empty
        when classifier-free guidance is disabled.
    """
    if isinstance(prompt, str):
        prompt = [prompt]
    prompt_embeds = _encode_prompt(
        prompt=prompt,
        device=device,
        prompt_embeds=prompt_embeds,
        max_sequence_length=max_sequence_length,
    )
    if not do_classifier_free_guidance:
        return prompt_embeds, []
    if negative_prompt is None:
        # Default to empty negatives, one per positive prompt.
        negative_prompt = [""] * len(prompt)
    elif isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt]
    assert len(prompt) == len(negative_prompt)
    negative_prompt_embeds = _encode_prompt(
        prompt=negative_prompt,
        device=device,
        prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
    )
    return prompt_embeds, negative_prompt_embeds
def _stack_prompt_embeddings(prompt_embeds_input):
if isinstance(prompt_embeds_input, list):
return torch.stack(prompt_embeds_input, dim=0)
return prompt_embeds_input
def _onnx_shape_compatible(model_path: str, providers, latent_shape, prompt_shape, verbose=False):
    """Check whether the ONNX transformer's static input dims match runtime shapes.

    Only integer (static) dims are compared; symbolic dims are skipped.
    Returns True on any inspection failure so the caller degrades gracefully.

    Args:
        model_path: path to the ONNX model.
        providers: ONNX Runtime providers for the probe session.
        latent_shape: runtime latent tensor shape (expects H at [3], W at [4]).
        prompt_shape: runtime prompt embedding shape (expects seq_len at [1]).
        verbose: print diagnostic details on mismatch.
    """
    try:
        sess = ort.InferenceSession(model_path, providers=providers)
        inputs = {i.name: i for i in sess.get_inputs()}
        ok = True
        mismatch_msgs = []
        if "latent_model_input" in inputs:
            shape = inputs["latent_model_input"].shape
            exp_h, exp_w = shape[3], shape[4]
            if isinstance(exp_h, int) and exp_h != latent_shape[3]:
                ok = False
                mismatch_msgs.append(f"latent_h expected {exp_h}, got {latent_shape[3]}")
            if isinstance(exp_w, int) and exp_w != latent_shape[4]:
                ok = False
                # Fixed: the original referenced an undefined name `latent_w`
                # here, raising NameError (swallowed by the except below) and
                # silently returning True on every width mismatch.
                mismatch_msgs.append(f"latent_w expected {exp_w}, got {latent_shape[4]}")
        if "prompt_embeds" in inputs:
            pshape = inputs["prompt_embeds"].shape
            exp_seq = pshape[1]
            if isinstance(exp_seq, int) and exp_seq != prompt_shape[1]:
                ok = False
                mismatch_msgs.append(f"seq_len expected {exp_seq}, got {prompt_shape[1]}")
        if verbose and (not ok or mismatch_msgs):
            print(f"[DEBUG] ONNX shape check for {model_path}")
            print(f"  providers={providers}")
            print(f"  model latent shape={inputs.get('latent_model_input').shape if 'latent_model_input' in inputs else 'n/a'}")
            print(f"  model prompt shape={inputs.get('prompt_embeds').shape if 'prompt_embeds' in inputs else 'n/a'}")
            print(f"  runtime latent shape={latent_shape}, prompt shape={prompt_shape}")
            if mismatch_msgs:
                print(f"  mismatch: {', '.join(mismatch_msgs)}")
        return ok
    except Exception as exc:
        print(f"ONNX shape check failed for {model_path}: {exc}")
        return True
def prepare_latents(
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Create (or validate) the initial latent noise for the diffusion loop.

    Pixel height/width are snapped down to the nearest multiple of the packed
    VAE stride (vae_scale_factor * 2) before converting to latent size.
    """
    latent_height = 2 * (int(height) // (vae_scale_factor * 2))
    latent_width = 2 * (int(width) // (vae_scale_factor * 2))
    shape = (batch_size, num_channels_latents, latent_height, latent_width)
    if latents is None:
        return randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    if latents.shape != shape:
        raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
    return latents.to(device)
# GPU memory mode (placeholder; this script currently loads everything fully)
GPU_memory_mode = "model_full_load"
ulysses_degree = 1
ring_degree = 1
fsdp_dit = False
fsdp_text_encoder = False
compile_dit = False
# Sampler choice; must be a key of the Chosen_Scheduler mapping below.
sampler_name = "Flow"
transformer_path = None
vae_path = None
lora_path = None
sample_size = [512, 512]  # H, W
prompt = "(masterpiece, best quality), 1 girl on the beach"
negative_prompt = " "
# Turbo model runs guidance-free (scale 0.0); CFG is disabled downstream.
guidance_scale = 0.0
seed = 42
num_inference_steps = 9
lora_weight = 0.55
save_path = "samples/z-image-t2i-nocontrol"
def _resolve_config_path(path: str) -> Optional[str]:
candidate = path if os.path.isabs(path) else os.path.join(repo_root, path)
return candidate if os.path.exists(candidate) else None
config_path = _resolve_config_path(config_path_default)
config = OmegaConf.load(config_path) if config_path else None
extra_kwargs = {}
if config is not None and hasattr(config, "transformer_additional_kwargs"):
    extra_kwargs = OmegaConf.to_container(config.transformer_additional_kwargs, resolve=True)
# Load the diffusion transformer (body only — no ControlNet weights).
transformer = ZImageTransformer2DModel.from_pretrained(
    model_name,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
    **({"transformer_additional_kwargs": extra_kwargs} if extra_kwargs else {}),
).to(weight_dtype).to(device)
if transformer_path is not None:
    # Optionally overlay a fine-tuned checkpoint on top of the base weights.
    print(f"From checkpoint: {transformer_path}")
    if transformer_path.endswith("safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(transformer_path)
    else:
        state_dict = torch.load(transformer_path, map_location="cpu")
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict
    # strict=False tolerates missing/extra keys (e.g. partial fine-tunes).
    m, u = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
# Map sampler name to its flow-matching scheduler implementation.
Chosen_Scheduler = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}[sampler_name]
scheduler = Chosen_Scheduler.from_pretrained(
    model_name,
    subfolder="scheduler"
)
height, width = sample_size
# VAE downsamples 8x and latents are packed 2x2 — sizes must divide by 16.
vae_scale_factor = 8
vae_scale = vae_scale_factor * 2
if height % vae_scale != 0 or width % vae_scale != 0:
    raise ValueError(f"Height/Width must be divisible by {vae_scale}")
_guidance_scale = guidance_scale
_joint_attention_kwargs = None
_interrupt = False
_cfg_normalization = False
_cfg_truncation = 1.0
prompt_embeds = None
if isinstance(prompt, str):
    batch_size = 1
elif isinstance(prompt, list):
    batch_size = len(prompt)
else:
    batch_size = len(prompt_embeds)
# NOTE(review): this unconditionally overrides the batch size computed just
# above, making the preceding if/elif dead code — confirm single-image is
# intended before removing.
batch_size = 1
# Run the rest of the pipeline in the text encoder's dtype.
weight_dtype = text_encoder.dtype
num_channels_latents = 16
inpaint_latent = None
from diffusers.image_processor import VaeImageProcessor
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)
vae = AutoencoderKL.from_pretrained(
    model_name,
    subfolder="vae"
).to(weight_dtype)
vae.eval()
vae_config_shift_factor = getattr(vae.config, "shift_factor", 0.0)
vae_config_scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
# No control image path in this launcher
control_context_scale_tensor = torch.tensor([control_context_scale], device=device, dtype=torch.float32)
# Enable the ONNX transformer only when the exported file actually exists.
onnx_transformer_enabled = use_transformer_onnx and os.path.exists(transformer_body_onnx_path)
if use_transformer_onnx and not onnx_transformer_enabled:
    print(f"ONNX transformer requested but missing file, fallback to torch. body={transformer_body_onnx_path}")
else:
    print(f"[DEBUG] ONNX providers: {onnx_providers}")
    print(f"[DEBUG] transformer body onnx path: {transformer_body_onnx_path}")
# Turbo model: classifier-free guidance disabled (guidance_scale is 0.0).
do_classifier_free_guidance = False
negative_prompt_embeds = None
max_sequence_length = 128
prompt_embeds, negative_prompt_embeds = encode_prompt(
    prompt=prompt,
    negative_prompt=negative_prompt,
    do_classifier_free_guidance=do_classifier_free_guidance,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    device=device,
    max_sequence_length=max_sequence_length,
)
num_images_per_prompt = 1
latents = None
# Latents are drawn in float32 for numerically stable scheduler updates.
latents = prepare_latents(
    batch_size * num_images_per_prompt,
    num_channels_latents,
    height,
    width,
    torch.float32,
    device,
    torch.Generator(device=device).manual_seed(seed),
    latents,
)
if num_images_per_prompt > 1:
    # Repeat each prompt embedding once per requested image.
    prompt_embeds = [pe for pe in prompt_embeds for _ in range(num_images_per_prompt)]
    if do_classifier_free_guidance and negative_prompt_embeds:
        negative_prompt_embeds = [npe for npe in negative_prompt_embeds for _ in range(num_images_per_prompt)]
actual_batch_size = batch_size * num_images_per_prompt
# Token count of the packed latent image (2x2 latent patches).
image_seq_len = (latents.shape[2] // 2) * (latents.shape[3] // 2)
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the flow-matching shift `mu` from the image token count.

    Maps base_seq_len -> base_shift and max_seq_len -> max_shift; values
    outside the range extrapolate along the same line.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return slope * image_seq_len + intercept
# Dynamic shift for the flow-matching schedule, from the image token count.
mu = calculate_shift(
    image_seq_len,
    scheduler.config.get("base_image_seq_len", 256),
    scheduler.config.get("max_image_seq_len", 4096),
    scheduler.config.get("base_shift", 0.5),
    scheduler.config.get("max_shift", 1.15),
)
# Allow the schedule to reach sigma == 0 exactly at the final step.
scheduler.sigma_min = 0.0
scheduler_kwargs = {"mu": mu}
import inspect
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    """Configure the scheduler and return its timestep schedule.

    Mirrors the diffusers helper: at most one of `timesteps`/`sigmas` may be
    given; otherwise `num_inference_steps` drives the schedule.

    Returns:
        (timesteps, num_inference_steps) after `scheduler.set_timesteps`.

    Raises:
        ValueError: if both `timesteps` and `sigmas` are given, or the
            scheduler does not support the requested custom schedule.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    set_timesteps_params = set(inspect.signature(scheduler.set_timesteps).parameters)
    if timesteps is not None:
        if "timesteps" not in set_timesteps_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        if "sigmas" not in set_timesteps_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
timesteps, num_inference_steps = retrieve_timesteps(
    scheduler,
    num_inference_steps=num_inference_steps,
    device=device,
    **scheduler_kwargs,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
# Denoising loop: one transformer forward (ONNX or torch fallback) per step.
for i, t in enumerate(timesteps):
    timestep = t.expand(latents.shape[0])
    # Timestep fed to the model is reversed and normalized to [0, 1].
    timestep_model_input = (1000 - timestep) / 1000
    latent_model_input = latents.to(transformer.dtype)
    prompt_embeds_tensor = _stack_prompt_embeddings(prompt_embeds)
    if not _printed_onnx_debug:
        print(f"[DEBUG] runtime latent shape: {latent_model_input.unsqueeze(2).shape}, dtype={latent_model_input.dtype}")
        print(f"[DEBUG] runtime prompt shape: {prompt_embeds_tensor.shape}, dtype={prompt_embeds_tensor.dtype}")
    if onnx_transformer_enabled:
        # ONNX path: add a frame axis (unsqueeze(2)) and cast to the export dtype.
        onnx_latent = latent_model_input.unsqueeze(2).to(dtype=onnx_export_dtype)
        onnx_prompt = prompt_embeds_tensor.to(dtype=onnx_export_dtype)
        # Permanently disable the ONNX path if its static shapes mismatch.
        if not _onnx_shape_compatible(transformer_body_onnx_path, onnx_providers, onnx_latent.shape, onnx_prompt.shape, verbose=True):
            print("ONNX transformer 输入尺寸与当前推理不匹配,回退到 torch。")
            onnx_transformer_enabled = False
        else:
            body_inputs = {
                "latent_model_input": onnx_latent,
                "timestep": timestep_model_input.to(dtype=torch.float32),
                "prompt_embeds": onnx_prompt,
            }
            model_out = ort_inference(transformer_body_onnx_path, body_inputs, providers=onnx_providers)[0]
            model_out_tensor = torch.from_numpy(model_out).to(device=device, dtype=onnx_export_dtype)
            model_out_list = list(model_out_tensor)
    if not onnx_transformer_enabled:
        # Torch fallback: the transformer consumes a list of per-sample latents.
        latent_model_input_list = list(latent_model_input.unsqueeze(2).unbind(dim=0))
        timestep_model_input = timestep_model_input.to(dtype=transformer.dtype)
        if device.type == "cuda":
            with torch.autocast(device_type="cuda", dtype=transformer.dtype):
                model_out_list = transformer(
                    latent_model_input_list,
                    timestep_model_input,
                    prompt_embeds,
                    patch_size=2,
                    f_patch_size=1,
                )[0]
        else:
            model_out_list = transformer(
                latent_model_input_list,
                timestep_model_input,
                prompt_embeds,
                patch_size=2,
                f_patch_size=1,
            )[0]
    noise_pred = torch.stack([t.float() for t in model_out_list], dim=0)
    # Drop the frame axis added for the model.
    noise_pred = noise_pred.squeeze(2)
    # NOTE(review): the prediction is negated before the scheduler step —
    # presumably the model's velocity sign convention; confirm against export.
    noise_pred = -noise_pred
    latents = scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
    assert latents.dtype == torch.float32
    if not _printed_onnx_debug:
        _printed_onnx_debug = True
# Decode the final latents to an image (ONNX VAE decoder when available).
output_type = "pil"
if output_type == "latent":
    image = latents
else:
    latents = latents.to(vae.dtype)
    # Undo the VAE latent normalization before decoding.
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor
    if os.path.exists(vae_decoder_onnx_path):
        image = ort_inference(
            vae_decoder_onnx_path,
            {"latents": latents},
        )[0]
        image = torch.from_numpy(image).to(device=device, dtype=vae.dtype)
    else:
        with module_to_device(vae, device):
            image = vae.decode(latents, return_dict=False)[0]
    image = image_processor.postprocess(image, output_type=output_type)
sample = image
def save_results():
    """Persist the first generated image under save_path with a zero-padded index."""
    os.makedirs(save_path, exist_ok=True)
    index = len(os.listdir(save_path)) + 1
    output_path = os.path.join(save_path, f"{index:08d}" + ".png")
    image = sample[0]
    image.save(output_path)
# Write the generated image to disk and report the output directory.
save_results()
logger.info(f"Saved image to {save_path}")