| | import os |
| | import sys |
| | import numpy as np |
| | import torch |
| | import onnxruntime as ort |
| | from contextlib import contextmanager |
| |
|
| | from omegaconf import OmegaConf |
| | from PIL import Image |
| | from loguru import logger |
| |
|
# Make this script's directory and its two nearest ancestors importable so the
# project packages resolve when the file is executed directly.
current_file_path = os.path.abspath(__file__)
_script_dir = os.path.dirname(current_file_path)
_script_parent = os.path.dirname(_script_dir)
_script_grandparent = os.path.dirname(_script_parent)
project_roots = [_script_dir, _script_parent, _script_grandparent]
for project_root in project_roots:
    if project_root not in sys.path:
        sys.path.insert(0, project_root)

# The outermost ancestor is treated as the repository root; exported ONNX
# graphs are looked up in its "onnx-models" directory.
repo_root = project_roots[-1]
onnx_models_dir = os.path.join(repo_root, "onnx-models")
vae_encoder_onnx_path = os.path.join(onnx_models_dir, "vae_encoder_simp_slim.onnx")
vae_decoder_onnx_path = os.path.join(onnx_models_dir, "vae_decoder_simp_slim.onnx")
controlnet_onnx_path = os.path.join(onnx_models_dir, "z_image_controlnet_simp_slim.onnx")
transformer_body_onnx_path = os.path.join(onnx_models_dir, "z_image_transformer_body_simp_slim.onnx")

# Feature switches for the denoising loop.
use_transformer_onnx = True
use_controlnet = False

# Floating dtype that tensors are cast to before being fed to the ONNX graphs.
onnx_export_dtype = torch.float16

# One-shot flag for the runtime shape debug prints.
_printed_onnx_debug = False
# Cache of the body graph's declared `control_hints` shape, keyed by
# (model path, providers tuple).
_body_input_cache = {}
| |
|
| |
|
def _select_onnx_providers():
    """Choose the ONNX Runtime execution providers for this machine.

    Returns:
        A provider list ending in the CPU provider. The CUDA provider is put
        first when ONNX Runtime reports it as available; otherwise the Azure
        provider is put first when available; otherwise only CPU is returned.
    """
    available = ort.get_available_providers()
    for preferred in ("CUDAExecutionProvider", "AzureExecutionProvider"):
        if preferred in available:
            return [preferred, "CPUExecutionProvider"]
    return ["CPUExecutionProvider"]


# Resolved once at import time; passed to the controlnet / transformer-body
# ONNX calls further down.
onnx_providers = _select_onnx_providers()
| |
|
| | from diffusers import FlowMatchEulerDiscreteScheduler |
| | from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler |
| | from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler |
| |
|
| | from videox_fun.models import (AutoencoderKL, AutoTokenizer, |
| | Qwen3ForCausalLM, ZImageControlTransformer2DModel) |
| | from typing import List, Optional, Union |
| | from diffusers.utils.torch_utils import randn_tensor |
| | from videox_fun.utils.utils import get_image_latent |
| |
|
| | torch.set_grad_enabled(False) |
| |
|
| | _ort_sessions = {} |
| |
|
| |
|
| | def _ort_type_to_torch(type_str: str): |
| | if type_str is None: |
| | return None |
| | if "float16" in type_str: |
| | return torch.float16 |
| | if "bfloat16" in type_str: |
| | return torch.bfloat16 |
| | if "float" in type_str: |
| | return torch.float32 |
| | return None |
| |
|
| |
|
def ort_inference(onnx_model_path: str, inputs: dict, providers=None):
    """Run an ONNX model through a cached ONNX Runtime session.

    Args:
        onnx_model_path: Path of the ``.onnx`` file.
        inputs: Mapping of graph input name to a numpy array or torch tensor.
        providers: Execution providers; defaults to CPU only.

    Returns:
        The list returned by ``session.run(None, ...)``.

    Raises:
        FileNotFoundError: No session is cached yet and the file is missing.
        TypeError: An input value is neither a numpy array nor a torch tensor.
    """
    providers = providers or ["CPUExecutionProvider"]
    session_key = (onnx_model_path, tuple(providers))
    session = _ort_sessions.get(session_key)
    if session is None:
        if not os.path.exists(onnx_model_path):
            raise FileNotFoundError(f"ONNX model not found: {onnx_model_path}")
        session = ort.InferenceSession(onnx_model_path, providers=providers)
        _ort_sessions[session_key] = session

    # Floating dtype each graph input declares (None for non-float inputs).
    declared_dtypes = {
        node.name: _ort_type_to_torch(getattr(node, "type", None))
        for node in session.get_inputs()
    }

    feed = {}
    for name, value in inputs.items():
        wanted = declared_dtypes.get(name)

        if isinstance(value, np.ndarray):
            if wanted is not None:
                # Anything that is not float16 is fed as float32.
                numpy_dtype = np.float16 if wanted == torch.float16 else np.float32
                if value.dtype != numpy_dtype:
                    value = value.astype(numpy_dtype)
            feed[name] = value
            continue

        if not torch.is_tensor(value):
            raise TypeError(f"Unsupported input type for key {name}: {type(value)}")

        if wanted is not None and value.dtype != wanted:
            value = value.to(dtype=wanted)
        elif value.dtype == torch.bfloat16:
            # bfloat16 tensors are widened to float32 before the numpy conversion.
            value = value.to(dtype=torch.float32)
        feed[name] = value.detach().to("cpu").numpy()

    return session.run(None, feed)
| |
|
| |
|
| | def _infer_module_device(module: torch.nn.Module) -> torch.device: |
| | param = next(module.parameters(), None) |
| | if param is not None: |
| | return param.device |
| | buffer = next(module.buffers(), None) |
| | if buffer is not None: |
| | return buffer.device |
| | return torch.device("cpu") |
| |
|
| |
|
@contextmanager
def module_to_device(module: torch.nn.Module, target_device: torch.device):
    """Temporarily host ``module`` on ``target_device``.

    On exit the module is moved back to the device it started on. If the
    temporary device was a CUDA device, that device's allocator cache is
    emptied afterwards. A ``None`` module, a falsy ``target_device`` or a
    target equal to the current device results in no move at all.

    Yields:
        The same ``module`` object (or ``None``).
    """
    if module is None:
        yield module
        return

    home_device = _infer_module_device(module)
    destination = target_device or home_device
    if destination == home_device:
        # Already in place: nothing to move and nothing to restore.
        yield module
        return

    module.to(destination)
    try:
        yield module
    finally:
        module.to(home_device)
        if destination.type == "cuda" and torch.cuda.is_available():
            cuda_index = destination.index
            if cuda_index is None:
                cuda_index = torch.cuda.current_device()
            with torch.cuda.device(cuda_index):
                torch.cuda.empty_cache()
| |
|
| |
|
| | |
# Locations of the transformer config and the pretrained model folder.
config_path = "config/z_image/z_image_control.yaml"
model_name = "models/Diffusion_Transformer/Z-Image-Turbo/"

# dtype used when loading the text encoder and the transformer.
weight_dtype = torch.bfloat16

# Control image path and the strength applied to its context.
control_image = "asset/pose_1024x1024.png"
control_context_scale = 0.75

# Prefer the first CUDA device, otherwise run on CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# The tokenizer and text encoder come from sub-folders of the same model directory.
tokenizer = AutoTokenizer.from_pretrained(model_name, subfolder="tokenizer")
text_encoder = Qwen3ForCausalLM.from_pretrained(
    model_name,
    subfolder="text_encoder",
    torch_dtype=weight_dtype,
    low_cpu_mem_usage=True,
)
text_encoder.eval()
| |
|
| | def _encode_prompt( |
| | prompt: Union[str, List[str]], |
| | device: Optional[torch.device] = None, |
| | prompt_embeds: Optional[List[torch.FloatTensor]] = None, |
| | max_sequence_length: int = 512, |
| | ) -> List[torch.FloatTensor]: |
| | device = device or torch.device("cpu") |
| |
|
| | if prompt_embeds is not None: |
| | return prompt_embeds |
| |
|
| | if isinstance(prompt, str): |
| | prompt = [prompt] |
| |
|
| | for i, prompt_item in enumerate(prompt): |
| | messages = [ |
| | {"role": "user", "content": prompt_item}, |
| | ] |
| | prompt_item = tokenizer.apply_chat_template( |
| | messages, |
| | tokenize=False, |
| | add_generation_prompt=True, |
| | enable_thinking=True, |
| | ) |
| | prompt[i] = prompt_item |
| |
|
| | text_inputs = tokenizer( |
| | prompt, |
| | padding="max_length", |
| | max_length=max_sequence_length, |
| | truncation=True, |
| | return_tensors="pt", |
| | ) |
| |
|
| | text_input_ids = text_inputs.input_ids.to(device) |
| | prompt_masks = text_inputs.attention_mask.to(device).bool() |
| |
|
| | with module_to_device(text_encoder, device): |
| | prompt_embeds = text_encoder( |
| | input_ids=text_input_ids, |
| | attention_mask=prompt_masks, |
| | output_hidden_states=True, |
| | ).hidden_states[-2] |
| |
|
| | embeddings_list = [] |
| |
|
| | for i in range(len(prompt_embeds)): |
| | |
| | embeddings_list.append(prompt_embeds[i]) |
| |
|
| | return embeddings_list |
| |
|
| |
|
def encode_prompt(
    prompt: Union[str, List[str]],
    device: Optional[torch.device] = None,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[List[torch.FloatTensor]] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    max_sequence_length: int = 512,
):
    """Encode the positive prompt and, for CFG, the negative prompt.

    Args:
        prompt: One prompt or a list of prompts; may be ``None`` when
            ``prompt_embeds`` is supplied.
        device: Device the text encoder runs on.
        do_classifier_free_guidance: Whether negative embeddings are needed.
        negative_prompt: Negative prompt(s); defaults to one empty string per
            positive prompt.
        prompt_embeds: Precomputed positive embeddings (skips encoding).
        negative_prompt_embeds: Precomputed negative embeddings (skips encoding).
        max_sequence_length: Token length used for padding and truncation.

    Returns:
        ``(prompt_embeds, negative_prompt_embeds)``. The second element is an
        empty list when classifier-free guidance is disabled.

    Raises:
        ValueError: The numbers of positive and negative prompts differ.
    """
    prompt = [prompt] if isinstance(prompt, str) else prompt
    prompt_embeds = _encode_prompt(
        prompt=prompt,
        device=device,
        prompt_embeds=prompt_embeds,
        max_sequence_length=max_sequence_length,
    )

    if not do_classifier_free_guidance:
        return prompt_embeds, []

    # With only precomputed embeddings there is no prompt list to count, so
    # the number of embeddings stands in for the number of prompts.
    num_prompts = len(prompt) if prompt is not None else len(prompt_embeds)
    if negative_prompt is None:
        negative_prompt = [""] * num_prompts
    elif isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt]
    if len(negative_prompt) != num_prompts:
        # Raised explicitly rather than asserted so the check survives `python -O`.
        raise ValueError(
            f"Got {num_prompts} prompt(s) but {len(negative_prompt)} negative prompt(s); "
            "the counts must match."
        )

    negative_prompt_embeds = _encode_prompt(
        prompt=negative_prompt,
        device=device,
        prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
    )
    return prompt_embeds, negative_prompt_embeds
| |
|
| |
|
| | def _stack_prompt_embeddings(prompt_embeds_input): |
| | if isinstance(prompt_embeds_input, list): |
| | return torch.stack(prompt_embeds_input, dim=0) |
| | return prompt_embeds_input |
| |
|
| |
|
| | def _onnx_shape_compatible(model_path: str, providers, latent_shape, prompt_shape, verbose=False): |
| | try: |
| | sess = ort.InferenceSession(model_path, providers=providers) |
| | inputs = {i.name: i for i in sess.get_inputs()} |
| | ok = True |
| | mismatch_msgs = [] |
| | if "latent_model_input" in inputs: |
| | shape = inputs["latent_model_input"].shape |
| | exp_h, exp_w = shape[3], shape[4] |
| | if isinstance(exp_h, int) and exp_h != latent_shape[3]: |
| | ok = False |
| | mismatch_msgs.append(f"latent_h expected {exp_h}, got {latent_shape[3]}") |
| | if isinstance(exp_w, int) and exp_w != latent_shape[4]: |
| | ok = False |
| | mismatch_msgs.append(f"latent_w expected {exp_w}, got {latent_shape[4]}") |
| | if "prompt_embeds" in inputs: |
| | pshape = inputs["prompt_embeds"].shape |
| | exp_seq = pshape[1] |
| | if isinstance(exp_seq, int) and exp_seq != prompt_shape[1]: |
| | ok = False |
| | mismatch_msgs.append(f"seq_len expected {exp_seq}, got {prompt_shape[1]}") |
| | if verbose and (not ok or mismatch_msgs): |
| | print(f"[DEBUG] ONNX shape check for {model_path}") |
| | print(f" providers={providers}") |
| | print(f" model latent shape={inputs.get('latent_model_input').shape if 'latent_model_input' in inputs else 'n/a'}") |
| | print(f" model prompt shape={inputs.get('prompt_embeds').shape if 'prompt_embeds' in inputs else 'n/a'}") |
| | print(f" runtime latent shape={latent_shape}, prompt shape={prompt_shape}") |
| | if mismatch_msgs: |
| | print(f" mismatch: {', '.join(mismatch_msgs)}") |
| | return ok |
| | except Exception as exc: |
| | print(f"ONNX shape check failed for {model_path}: {exc}") |
| | return True |
| |
|
| |
|
def prepare_latents(
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """Create (or validate) the initial latent tensor.

    The latent height/width are derived from the pixel size using the
    module-level ``vae_scale_factor`` (read at call time), rounded down to an
    even number of latent cells.

    Args:
        batch_size: Number of latent samples.
        num_channels_latents: Channel count of the latent tensor.
        height: Image height in pixels.
        width: Image width in pixels.
        dtype: dtype for freshly sampled noise.
        device: Device for the returned tensor.
        generator: Random generator used for sampling.
        latents: Optional pre-made latents; must already have the expected shape.

    Returns:
        Random latents, or the supplied ``latents`` moved to ``device``.

    Raises:
        ValueError: Supplied ``latents`` do not have the expected shape.
    """
    pixels_per_two_cells = vae_scale_factor * 2
    latent_height = 2 * (int(height) // pixels_per_two_cells)
    latent_width = 2 * (int(width) // pixels_per_two_cells)
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    if latents is not None:
        if latents.shape != shape:
            raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
        return latents.to(device)

    return randn_tensor(shape, generator=generator, device=device, dtype=dtype)
| |
|
| |
|
def _make_zero_control_hints(model_path: str, providers, device, dtype):
    """Build an all-zero tensor matching the body graph's ``control_hints`` input.

    The declared input shape is read from the ONNX model once and remembered
    in ``_body_input_cache``. Dimensions that are not plain ints are
    replaced by 1.

    Raises:
        RuntimeError: The model has no ``control_hints`` input.
    """
    cache_key = (model_path, tuple(providers))
    if cache_key not in _body_input_cache:
        probe = ort.InferenceSession(model_path, providers=providers)
        declared_inputs = {node.name: node for node in probe.get_inputs()}
        if "control_hints" not in declared_inputs:
            raise RuntimeError("Body ONNX 未找到 control_hints 输入,无法构造零填充。")
        _body_input_cache[cache_key] = declared_inputs["control_hints"].shape

    declared_shape = _body_input_cache[cache_key]
    concrete_shape = [dim if isinstance(dim, int) else 1 for dim in declared_shape]
    return torch.zeros(concrete_shape, device=device, dtype=dtype)
| |
|
| |
|
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
# --- Run configuration -------------------------------------------------------
# NOTE(review): the memory-mode, parallelism, FSDP and compile settings below
# are not referenced anywhere else in the visible script.
GPU_memory_mode = "model_cpu_offload"
ulysses_degree = 1
ring_degree = 1
fsdp_dit = False
fsdp_text_encoder = False
compile_dit = False

# Key into the scheduler table further down.
sampler_name = "Flow"

# Optional extra checkpoints.
transformer_path = "models/Personalized_Model/Z-Image-Turbo-Fun-Controlnet-Union.safetensors"
vae_path = None
lora_path = None

# Output size as [height, width].
sample_size = [512, 512]

# Generation settings.
prompt = "(masterpiece, best quality, ultra detailed, 8k, CG unity wallpaper),1 young beautiful girl, full body, official art, extremely detailed, highly detailed, 1 girl, aqua eyes, light smile, grey hair, hair flower,bracelet, choker, ribbon, JK, looking at viewer, on the beach, in summer,"
negative_prompt = " "
guidance_scale = 0.00
seed = 43
num_inference_steps = 9
lora_weight = 0.55
save_path = "samples/z-image-t2i-control"

# --- Transformer ---------------------------------------------------------------
config = OmegaConf.load(config_path)
_transformer_extra_kwargs = OmegaConf.to_container(config['transformer_additional_kwargs'])
transformer = ZImageControlTransformer2DModel.from_pretrained(
    model_name,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
    transformer_additional_kwargs=_transformer_extra_kwargs,
)
transformer = transformer.to(weight_dtype).to(device)

if transformer_path is not None:
    print(f"From checkpoint: {transformer_path}")
    if transformer_path.endswith("safetensors"):
        from safetensors.torch import load_file, safe_open
        state_dict = load_file(transformer_path)
    else:
        state_dict = torch.load(transformer_path, map_location="cpu")
    # Some checkpoints nest the weights under a "state_dict" entry.
    state_dict = state_dict.get("state_dict", state_dict)

    # Non-strict load: the checkpoint may cover only part of the model.
    m, u = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")

# --- Scheduler -----------------------------------------------------------------
_SCHEDULER_CLASSES = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
Chosen_Scheduler = _SCHEDULER_CLASSES[sampler_name]

scheduler = Chosen_Scheduler.from_pretrained(model_name, subfolder="scheduler")
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
# Seeded generator on the inference device, for reproducible noise.
generator = torch.Generator(device=device).manual_seed(seed)

# Load the control image from disk; the indexing drops the third axis of the
# loader's output.
if control_image is not None:
    control_image = get_image_latent(control_image, sample_size=sample_size)[:, :, 0]
    # Debug output; kept under the guard so a missing control image does not
    # fail here on `None.shape`.
    print(control_image.shape)

height, width = sample_size

vae_scale_factor = 8
vae_scale = vae_scale_factor * 2

# Both spatial sizes must be multiples of `vae_scale`.
for _size_label, _size_value in (("Height", height), ("Width", width)):
    if _size_value % vae_scale != 0:
        raise ValueError(
            f"{_size_label} must be divisible by {vae_scale} (got {_size_value}). "
            f"Please adjust the {_size_label.lower()} to a multiple of {vae_scale}."
        )

# Guidance-related settings. Note that `guidance_scale` is (re)set to 0.0 here.
_guidance_scale = guidance_scale = 0.0
_joint_attention_kwargs = joint_attention_kwargs = None
_interrupt = False
_cfg_normalization = cfg_normalization = False
_cfg_truncation = cfg_truncation = 1.0
| |
|
| |
|
| | |
# No precomputed prompt embeddings are supplied; the prompt is encoded later.
prompt_embeds = None

# The script always runs a single sample. (The earlier prompt-type based
# computation was immediately overwritten with 1 and its fallback branch
# called len() on None, so it is dropped.)
batch_size = 1

weight_dtype = text_encoder.dtype
num_channels_latents = 16
inpaint_latent = None

from diffusers.image_processor import VaeImageProcessor
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)

# VAE weights are cast to the text encoder's dtype.
vae = AutoencoderKL.from_pretrained(model_name, subfolder="vae").to(weight_dtype)
vae.eval()
vae_config_shift_factor = getattr(vae.config, "shift_factor", 0.0)
vae_config_scaling_factor = getattr(vae.config, "scaling_factor", 1.0)

if control_image is not None:
    control_image = image_processor.preprocess(control_image, height=height, width=width)
    control_image = control_image.to(dtype=weight_dtype, device=device)
    if os.path.exists(vae_encoder_onnx_path):
        # Prefer the exported ONNX encoder when the file is present.
        control_latents = ort_inference(
            vae_encoder_onnx_path,
            {"pixel_values": control_image},
        )[0]
        control_latents = torch.from_numpy(control_latents).to(device=device, dtype=weight_dtype)
    else:
        with module_to_device(vae, device):
            control_latents = vae.encode(control_image)[0].mode()
    # NOTE(review): shift/scale is applied after both encoder paths, mirroring
    # the decode step, which un-scales before either decoder. This assumes the
    # ONNX encoder returns raw latents -- confirm against the export.
    control_latents = (control_latents - vae_config_shift_factor) * vae_config_scaling_factor
elif inpaint_latent is not None:
    control_latents = torch.zeros_like(inpaint_latent)
else:
    # Previously this fell through to torch.zeros_like(None) and failed with
    # an opaque TypeError.
    raise ValueError(
        "control_image is None and no inpaint_latent is available to size zero control latents."
    )

control_context = control_latents.unsqueeze(2)
control_context_scale_tensor = torch.tensor([control_context_scale], device=device, dtype=torch.float32)

# With the controlnet disabled, both the context and its scale are zeroed.
if not use_controlnet:
    control_context = torch.zeros_like(control_context)
    control_context_scale_tensor = torch.zeros_like(control_context_scale_tensor)

# The ONNX transformer path needs both exported graphs on disk.
onnx_transformer_enabled = (
    use_transformer_onnx
    and os.path.exists(controlnet_onnx_path)
    and os.path.exists(transformer_body_onnx_path)
)
if use_transformer_onnx and not onnx_transformer_enabled:
    print(f"ONNX transformer requested but missing files, fallback to torch. controlnet={controlnet_onnx_path}, body={transformer_body_onnx_path}")
else:
    print(f"[DEBUG] ONNX providers: {onnx_providers}")
    print(f"[DEBUG] controlnet onnx path: {controlnet_onnx_path}")
    print(f"[DEBUG] transformer body onnx path: {transformer_body_onnx_path}")
| |
|
| |
|
# Classifier-free guidance is off for this run, so no negative embeddings are built.
do_classifier_free_guidance = False
negative_prompt_embeds = None
max_sequence_length = 128

prompt_embeds, negative_prompt_embeds = encode_prompt(
    prompt=prompt,
    device=device,
    do_classifier_free_guidance=do_classifier_free_guidance,
    negative_prompt=negative_prompt,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    max_sequence_length=max_sequence_length,
)

num_images_per_prompt = 1
latents = None

# Reference values recorded in a pdb session:
#   (Pdb) latents[0, 0, 100:105, 100]
#   tensor([-0.9203, 1.3958, 0.8130, -0.5280, -1.9788], device='cuda:0')

# Initial noise is sampled in float32.
latents = prepare_latents(
    batch_size=batch_size * num_images_per_prompt,
    num_channels_latents=num_channels_latents,
    height=height,
    width=width,
    dtype=torch.float32,
    device=device,
    generator=generator,
    latents=latents,
)
print(latents.shape)

# Repeat each embedding once per requested image.
if num_images_per_prompt > 1:
    prompt_embeds = [
        embedding
        for embedding in prompt_embeds
        for _copy in range(num_images_per_prompt)
    ]
    if do_classifier_free_guidance and negative_prompt_embeds:
        negative_prompt_embeds = [
            embedding
            for embedding in negative_prompt_embeds
            for _copy in range(num_images_per_prompt)
        ]

actual_batch_size = batch_size * num_images_per_prompt
# Half the latent height times half the latent width.
_half_latent_h = latents.shape[2] // 2
_half_latent_w = latents.shape[3] // 2
image_seq_len = _half_latent_h * _half_latent_w
| |
|
| |
|
| | |
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the shift value for a given image sequence length.

    The result lies on the straight line through
    ``(base_seq_len, base_shift)`` and ``(max_seq_len, max_shift)``; values
    outside that range are extrapolated, not clamped.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
| |
|
| | |
# Shift parameter for this resolution; bounds come from the scheduler config,
# with the listed defaults when a key is absent.
mu = calculate_shift(
    image_seq_len,
    base_seq_len=scheduler.config.get("base_image_seq_len", 256),
    max_seq_len=scheduler.config.get("max_image_seq_len", 4096),
    base_shift=scheduler.config.get("base_shift", 0.5),
    max_shift=scheduler.config.get("max_shift", 1.15),
)
scheduler.sigma_min = 0.0
# Extra keyword arguments forwarded to `scheduler.set_timesteps`.
scheduler_kwargs = {"mu": mu}
| |
|
| |
|
| | import inspect |
| | |
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure the scheduler via `set_timesteps` and return its resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or a plain
    `num_inference_steps` count. Any extra kwargs are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read timesteps from.
        num_inference_steps (`int`, *optional*):
            Number of steps; used only when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Passed through to `set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Cannot be combined with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Cannot be combined with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: the scheduler's `timesteps` and the number of inference steps (the length of
        the schedule for custom schedules, otherwise `num_inference_steps` as passed in).

    Raises:
        ValueError: Both `timesteps` and `sigmas` were passed, or the scheduler's `set_timesteps` does not
            declare the custom argument that was requested.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is None and sigmas is None:
        # Default mode: let the scheduler space the requested number of steps.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom schedules are only legal if `set_timesteps` names the parameter.
    accepted_params = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in accepted_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    schedule = scheduler.timesteps
    return schedule, len(schedule)
| |
|
| |
|
# No custom sigma schedule: the scheduler spaces `num_inference_steps` itself.
sigmas = None
timesteps, num_inference_steps = retrieve_timesteps(
    scheduler,
    num_inference_steps=num_inference_steps,
    device=device,
    sigmas=sigmas,
    **scheduler_kwargs,
)
_steps_covered_by_order = num_inference_steps * scheduler.order
num_warmup_steps = max(len(timesteps) - _steps_covered_by_order, 0)
# Schedule recorded in a pdb session:
#   (Pdb) timesteps
#   tensor([1000.0000, 954.5454, 900.0000, 833.3333, 750.0000, 642.8571,
#           500.0000, 300.0000, 0.0000], device='cuda:0')
_num_timesteps = len(timesteps)

# Optional per-step hook and the names of the variables handed to it.
callback_on_step_end = None
callback_on_step_end_tensor_inputs = ['latents']
| |
|
| | |
| | |
# The ONNX shape probes compare only latent height/width and prompt length,
# which do not change between steps, so they are run once instead of per step.
_onnx_shapes_checked = False

# Denoising loop: one transformer evaluation and one scheduler step per timestep.
for i, t in enumerate(timesteps):
    # The model is fed (1000 - t) / 1000 rather than the raw scheduler timestep.
    timestep = t.expand(latents.shape[0])
    timestep = (1000 - timestep) / 1000
    t_norm = timestep[0].item()

    # Guidance is dropped once the normalized time exceeds the truncation point.
    current_guidance_scale = guidance_scale
    if (
        do_classifier_free_guidance
        and _cfg_truncation is not None
        and float(_cfg_truncation) <= 1
        and t_norm > _cfg_truncation
    ):
        current_guidance_scale = 0.0

    apply_cfg = do_classifier_free_guidance and current_guidance_scale > 0

    if apply_cfg:
        # Positive and negative branches are evaluated in one doubled batch.
        latent_model_input = latents.to(transformer.dtype).repeat(2, 1, 1, 1)
        prompt_embeds_model_input = [p.to(transformer.dtype) for p in (prompt_embeds + negative_prompt_embeds)]
        timestep_model_input = timestep.repeat(2)
    else:
        latent_model_input = latents.to(transformer.dtype)
        prompt_embeds_model_input = [p.to(transformer.dtype) for p in prompt_embeds]
        timestep_model_input = timestep

    latent_model_input = latent_model_input.unsqueeze(2)
    prompt_embeds_tensor = _stack_prompt_embeddings(prompt_embeds_model_input)

    if not _printed_onnx_debug:
        print(f"[DEBUG] runtime latent shape: {latent_model_input.shape}, dtype={latent_model_input.dtype}")
        print(f"[DEBUG] runtime prompt shape: {prompt_embeds_tensor.shape}, dtype={prompt_embeds_tensor.dtype}")
        print(f"[DEBUG] runtime control_context shape: {control_context.shape}, dtype={control_context.dtype}")

    if onnx_transformer_enabled:
        onnx_latent = latent_model_input.to(dtype=onnx_export_dtype)
        onnx_prompt = prompt_embeds_tensor.to(dtype=onnx_export_dtype)
        onnx_control = control_context.to(dtype=onnx_export_dtype)
        if apply_cfg:
            onnx_control = onnx_control.repeat(2, 1, 1, 1, 1)

        if not _onnx_shapes_checked:
            _onnx_shapes_checked = True
            if not _onnx_shape_compatible(controlnet_onnx_path, onnx_providers, onnx_latent.shape, onnx_prompt.shape, verbose=True):
                print(f"ONNX controlnet 输入尺寸与当前推理不匹配,回退到 torch。模型期望与当前 latent/prompt 尺寸不同。")
                onnx_transformer_enabled = False
            elif not _onnx_shape_compatible(transformer_body_onnx_path, onnx_providers, onnx_latent.shape, onnx_prompt.shape, verbose=True):
                print(f"ONNX transformer body 输入尺寸与当前推理不匹配,回退到 torch。模型期望与当前 latent/prompt 尺寸不同。")
                # The announced fallback has to actually happen: without this
                # the mismatching ONNX body would still be executed below.
                onnx_transformer_enabled = False

    if onnx_transformer_enabled:
        onnx_timestep = timestep_model_input.to(dtype=torch.float32)
        if use_controlnet:
            control_inputs = {
                "latent_model_input": onnx_latent,
                "timestep": onnx_timestep,
                "prompt_embeds": onnx_prompt,
                "control_context": onnx_control,
            }
            control_hints = ort_inference(controlnet_onnx_path, control_inputs, providers=onnx_providers)[0]
            control_hints_tensor = torch.from_numpy(control_hints).to(device=device, dtype=onnx_export_dtype)
        else:
            # Controlnet disabled: feed the body graph all-zero hints.
            control_hints_tensor = _make_zero_control_hints(transformer_body_onnx_path, onnx_providers, device, onnx_export_dtype)

        if use_controlnet:
            onnx_control_scale = control_context_scale_tensor.to(dtype=torch.float32)
        else:
            onnx_control_scale = torch.zeros_like(control_context_scale_tensor).to(dtype=torch.float32)

        body_inputs = {
            "latent_model_input": onnx_latent,
            "timestep": onnx_timestep,
            "prompt_embeds": onnx_prompt,
            "control_hints": control_hints_tensor,
            "control_context_scale": onnx_control_scale,
        }
        model_out = ort_inference(transformer_body_onnx_path, body_inputs, providers=onnx_providers)[0]
        model_out_tensor = torch.from_numpy(model_out).to(device=device, dtype=onnx_export_dtype)
        model_out_list = list(model_out_tensor)
    else:
        # Torch fallback: the transformer takes a list of per-sample latents.
        latent_model_input_list = [sample.to(transformer.dtype) for sample in latent_model_input.unbind(dim=0)]
        control_context_input = control_context.repeat(2, 1, 1, 1, 1) if apply_cfg else control_context
        if not use_controlnet:
            control_context_input = torch.zeros_like(control_context_input)
            control_context_scale_zero = torch.zeros_like(control_context_scale_tensor)
        else:
            control_context_scale_zero = control_context_scale_tensor

        control_context_input = control_context_input.to(dtype=transformer.dtype)
        control_context_scale_zero = control_context_scale_zero.to(dtype=transformer.dtype)
        timestep_model_input = timestep_model_input.to(dtype=transformer.dtype)

        transformer_control_kwargs = {
            "control_context": control_context_input,
            "control_context_scale": control_context_scale_zero,
        }
        if device.type == "cuda":
            # Autocast is only entered on CUDA.
            with torch.autocast(device_type="cuda", dtype=transformer.dtype):
                model_out_list = transformer(
                    latent_model_input_list,
                    timestep_model_input,
                    prompt_embeds_model_input,
                    **transformer_control_kwargs,
                )[0]
        else:
            model_out_list = transformer(
                latent_model_input_list,
                timestep_model_input,
                prompt_embeds_model_input,
                **transformer_control_kwargs,
            )[0]

    # The shape debug prints above are emitted on the first step only.
    if not _printed_onnx_debug:
        _printed_onnx_debug = True

    if apply_cfg:
        # First half of the doubled batch is the positive branch, second half the negative.
        pos_out = model_out_list[:actual_batch_size]
        neg_out = model_out_list[actual_batch_size:]

        noise_pred = []
        for j in range(actual_batch_size):
            pos = pos_out[j].float()
            neg = neg_out[j].float()

            pred = pos + current_guidance_scale * (pos - neg)

            # Optionally cap the guided prediction's norm relative to the positive one.
            if _cfg_normalization and float(_cfg_normalization) > 0.0:
                ori_pos_norm = torch.linalg.vector_norm(pos)
                new_pos_norm = torch.linalg.vector_norm(pred)
                max_new_norm = ori_pos_norm * float(_cfg_normalization)
                if new_pos_norm > max_new_norm:
                    pred = pred * (max_new_norm / new_pos_norm)

            noise_pred.append(pred)

        noise_pred = torch.stack(noise_pred, dim=0)
    else:
        noise_pred = torch.stack([out.float() for out in model_out_list], dim=0)

    # Drop the singleton axis added before the model call and flip the sign
    # before handing the prediction to the scheduler.
    noise_pred = noise_pred.squeeze(2)
    noise_pred = -noise_pred

    latents = scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
    assert latents.dtype == torch.float32

    if callback_on_step_end is not None:
        callback_kwargs = {}
        for k in callback_on_step_end_tensor_inputs:
            callback_kwargs[k] = locals()[k]
        # There is no pipeline object in this flat script, so None is passed
        # in the slot where `self` used to be (which was an undefined name).
        callback_outputs = callback_on_step_end(None, i, t, callback_kwargs)

        latents = callback_outputs.pop("latents", latents)
        prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
        negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
| |
|
| | |
| | |
| | |
| |
|
| | |
| |
|
# Debug notes recorded in pdb sessions while aligning the implementations:
#
# Original data (VAE ONNX version)
#   (Pdb) latents.shape
#   torch.Size([1, 16, 216, 124])
#   (Pdb) latents[0, 0, 100, 100:105]
#   tensor([ 0.8545, 1.0117, 0.7908, -0.7002, 0.3965], device='cuda:0')
#
# Current data, launcher.py
#   (Pdb) latents.shape
#   torch.Size([1, 16, 216, 124])
#   (Pdb) latents[0, 0, 100, 100:105]
#   tensor([ 0.2899, 0.7049, 0.4407, -1.2531, -0.0161], device='cuda:2')
#
# Result after aligning the embeddings (torch version)
#   (Pdb) latents.shape
#   torch.Size([1, 16, 216, 124])
#   (Pdb) latents[0, 0, 100, 100:105]
#   tensor([ 0.8763, 1.0138, 0.8031, -0.6851, 0.3751], device='cuda:2')  # same result on cuda:0
#
# After switching to the ONNX version
#   (Pdb) latents[0, 0, 100, 100:105]
#   tensor([ 0.8545, 1.0117, 0.7908, -0.7002, 0.3965], device='cuda:0')
#
# The results above are aligned.

output_type = "pil"
if output_type != "latent":
    # Undo the latent scaling/shift before decoding.
    latents = latents.to(vae.dtype)
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

    if not os.path.exists(vae_decoder_onnx_path):
        # No exported decoder on disk: decode with the torch VAE.
        with module_to_device(vae, device):
            image = vae.decode(latents, return_dict=False)[0]
    else:
        decoded = ort_inference(
            vae_decoder_onnx_path,
            {"latents": latents},
        )[0]
        image = torch.from_numpy(decoded).to(device=device, dtype=vae.dtype)
    # Values recorded in a pdb session:
    #   (Pdb) latents[0, 0, 100, 100:105]
    #   tensor([ 2.4844, 2.9062, 2.2969, -1.8203, 1.2188], device='cuda:0',
    #          dtype=torch.bfloat16)
    #   (Pdb) image[0, 0, 100, 100:105]
    #   tensor([0.3906, 0.3848, 0.3809, 0.3809, 0.3848], device='cuda:0',
    #          dtype=torch.bfloat16)
    image = image_processor.postprocess(image, output_type=output_type)
else:
    image = latents
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
sample = image


def save_results():
    """Write the first generated image into ``save_path`` as a numbered PNG.

    The file name is the current number of directory entries plus one,
    zero-padded to eight digits.
    """
    if not os.path.exists(save_path):
        os.makedirs(save_path, exist_ok=True)

    next_index = len(os.listdir(save_path)) + 1
    target_file = os.path.join(save_path, f"{next_index:08d}.png")
    sample[0].save(target_file)


save_results()
logger.info(f"Saved image to {save_path}")