| import os |
| import sys |
| import numpy as np |
| import torch |
|
|
| from omegaconf import OmegaConf |
| from PIL import Image |
|
|
# Make this file's directory and its first two ancestors importable, so the
# `videox_fun` package resolves without installing it. A directory already on
# sys.path is not added again.
current_file_path = os.path.abspath(__file__)
project_roots = []
_ancestor = current_file_path
for _ in range(3):
    _ancestor = os.path.dirname(_ancestor)
    project_roots.append(_ancestor)
for project_root in project_roots:
    if project_root not in sys.path:
        sys.path.insert(0, project_root)
|
|
| from diffusers import FlowMatchEulerDiscreteScheduler |
| from videox_fun.utils.fm_solvers import FlowDPMSolverMultistepScheduler |
| from videox_fun.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler |
|
|
| from videox_fun.models import (AutoencoderKL, AutoTokenizer, |
| Qwen3ForCausalLM, ZImageControlTransformer2DModel) |
| from typing import List, Optional, Union |
| from diffusers.utils.torch_utils import randn_tensor |
| from videox_fun.utils.utils import get_image_latent |
|
|
|
|
| |
# Model config and base Z-Image-Turbo checkpoint directory.
config_path = "config/z_image/z_image_control.yaml"
model_name = "models/Diffusion_Transformer/Z-Image-Turbo/"

# Model precision and control-image settings.
weight_dtype = torch.bfloat16
control_image = "asset/pose.jpg"
control_context_scale = 0.75

# This script only runs inference. Disable autograd globally so the text
# encoder, VAE and transformer forward passes do not record graphs. Otherwise
# the graph chains through `latents` across denoising steps and memory grows
# every step.
torch.set_grad_enabled(False)

# NOTE(review): GPU index 2 is machine-specific and fails on hosts with fewer
# than three GPUs; adjust to the device you want to use.
device = torch.device('cuda:2') if torch.cuda.is_available() else torch.device('cpu')

# Tokenizer and Qwen3 text encoder from the base checkpoint.
tokenizer = AutoTokenizer.from_pretrained(
    model_name, subfolder="tokenizer"
)
text_encoder = Qwen3ForCausalLM.from_pretrained(
    model_name, subfolder="text_encoder", torch_dtype=weight_dtype,
    low_cpu_mem_usage=True,
).to(device)
|
|
| def _encode_prompt( |
| prompt: Union[str, List[str]], |
| device: Optional[torch.device] = None, |
| prompt_embeds: Optional[List[torch.FloatTensor]] = None, |
| max_sequence_length: int = 512, |
| ) -> List[torch.FloatTensor]: |
| |
|
|
| if prompt_embeds is not None: |
| return prompt_embeds |
|
|
| if isinstance(prompt, str): |
| prompt = [prompt] |
|
|
| for i, prompt_item in enumerate(prompt): |
| messages = [ |
| {"role": "user", "content": prompt_item}, |
| ] |
| prompt_item = tokenizer.apply_chat_template( |
| messages, |
| tokenize=False, |
| add_generation_prompt=True, |
| enable_thinking=True, |
| ) |
| prompt[i] = prompt_item |
|
|
| text_inputs = tokenizer( |
| prompt, |
| padding="max_length", |
| max_length=max_sequence_length, |
| truncation=True, |
| return_tensors="pt", |
| ) |
|
|
| text_input_ids = text_inputs.input_ids.to(device) |
| prompt_masks = text_inputs.attention_mask.to(device).bool() |
|
|
| prompt_embeds = text_encoder( |
| input_ids=text_input_ids, |
| attention_mask=prompt_masks, |
| output_hidden_states=True, |
| ).hidden_states[-2] |
|
|
| embeddings_list = [] |
|
|
| for i in range(len(prompt_embeds)): |
| embeddings_list.append(prompt_embeds[i][prompt_masks[i]]) |
| |
|
|
| return embeddings_list |
|
|
|
|
def encode_prompt(
    prompt: Union[str, List[str]],
    device: Optional[torch.device] = None,
    do_classifier_free_guidance: bool = True,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[List[torch.FloatTensor]] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    max_sequence_length: int = 512,
):
    """Encode the positive prompt(s) and, for CFG, the negative prompt(s).

    Args:
        prompt: Prompt string or list of strings. May be ``None`` when
            ``prompt_embeds`` is supplied.
        device: Device forwarded to ``_encode_prompt``.
        do_classifier_free_guidance: When False, no negative prompt is encoded.
        negative_prompt: Negative prompt string or list. Defaults to one empty
            string per prompt.
        prompt_embeds: Precomputed positive embeddings (returned as-is).
        negative_prompt_embeds: Precomputed negative embeddings (returned as-is).
        max_sequence_length: Tokenizer padding / truncation length.

    Returns:
        ``(prompt_embeds, negative_prompt_embeds)``. The second element is an
        empty list when guidance is disabled.

    Raises:
        ValueError: If neither ``prompt`` nor ``prompt_embeds`` is given, or
            the number of negative prompts differs from the number of prompts.
    """
    if prompt is None and prompt_embeds is None:
        raise ValueError("Either `prompt` or `prompt_embeds` must be provided.")

    prompt = [prompt] if isinstance(prompt, str) else prompt
    prompt_embeds = _encode_prompt(
        prompt=prompt,
        device=device,
        prompt_embeds=prompt_embeds,
        max_sequence_length=max_sequence_length,
    )

    if not do_classifier_free_guidance:
        return prompt_embeds, []

    # With only precomputed embeddings there are no prompt strings to count,
    # so count the embeddings instead of iterating `None`.
    num_prompts = len(prompt) if prompt is not None else len(prompt_embeds)
    if negative_prompt is None:
        negative_prompt = [""] * num_prompts
    elif isinstance(negative_prompt, str):
        negative_prompt = [negative_prompt]
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if len(negative_prompt) != num_prompts:
        raise ValueError(
            f"Got {len(negative_prompt)} negative prompt(s) for {num_prompts} "
            f"prompt(s); the counts must match."
        )

    negative_prompt_embeds = _encode_prompt(
        prompt=negative_prompt,
        device=device,
        prompt_embeds=negative_prompt_embeds,
        max_sequence_length=max_sequence_length,
    )
    return prompt_embeds, negative_prompt_embeds
|
|
|
|
def prepare_latents(
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
    scale_factor=None,
):
    """Create the initial latent noise, or validate user-supplied latents.

    Args:
        batch_size: Number of latent samples.
        num_channels_latents: Latent channel count.
        height: Pixel height of the target image.
        width: Pixel width of the target image.
        dtype: dtype of freshly drawn noise.
        device: Device of the returned tensor.
        generator: Generator passed to ``randn_tensor`` for reproducible noise.
        latents: Optional precomputed latents. They must already have the
            expected shape and are only moved to ``device``.
        scale_factor: Pixel-to-latent downscale factor. Defaults to the
            module-level ``vae_scale_factor``. Passing it explicitly removes
            the dependency on that global being defined before the call.

    Returns:
        A ``(batch_size, num_channels_latents, latent_h, latent_w)`` tensor.
        ``latent_h``/``latent_w`` are the pixel size divided by
        ``scale_factor``, rounded down to an even number (presumably so 2x2
        patches tile the latent exactly).

    Raises:
        ValueError: If ``latents`` is given with an unexpected shape.
    """
    if scale_factor is None:
        scale_factor = vae_scale_factor
    latent_height = 2 * (int(height) // (scale_factor * 2))
    latent_width = 2 * (int(width) // (scale_factor * 2))
    shape = (batch_size, num_channels_latents, latent_height, latent_width)

    if latents is None:
        return randn_tensor(shape, generator=generator, device=device, dtype=dtype)
    if latents.shape != shape:
        raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
    return latents.to(device)
|
|
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
# ---- Memory / parallelism / compilation options --------------------------
# NOTE(review): these flags, vae_path, lora_path and lora_weight are not read
# anywhere in the visible script; confirm whether they are still needed.
GPU_memory_mode = "model_cpu_offload"
ulysses_degree, ring_degree = 1, 1
fsdp_dit = fsdp_text_encoder = False
compile_dit = False

# ---- Sampler and weight locations ----------------------------------------
# One of "Flow", "Flow_Unipc", "Flow_DPM++" (see the scheduler table below).
sampler_name = "Flow"
transformer_path = "models/Personalized_Model/Z-Image-Turbo-Fun-Controlnet-Union.safetensors"
vae_path = lora_path = None

# ---- Generation settings ---------------------------------------------------
# [height, width] in pixels; both must be multiples of 16.
sample_size = [1728, 992]
prompt = "一位年轻女子站在阳光明媚的海岸线上, 白裙在轻拂的海风中微微飘动.她拥有一头鲜艳的紫色长发, 在风中轻盈舞动, 发间系着一个精致的黑色蝴蝶结, 与身后柔和的蔚蓝天空形成鲜明对比.她面容清秀, 眉目精致, 透着一股甜美的青春气息;神情柔和, 略带羞涩, 目光静静地凝望着远方的地平线, 双手自然交叠于身前, 仿佛沉浸在思绪之中.在她身后, 是辽阔无垠、波光粼粼的大海, 阳光洒在海面上, 映出温暖的金色光晕."
negative_prompt = " "
guidance_scale = 0.00
seed = 43
num_inference_steps = 9
lora_weight = 0.55
save_path = "samples/z-image-t2i-control"
|
|
config = OmegaConf.load(config_path)

# Build the control transformer from the base checkpoint. Extra constructor
# kwargs come from the YAML config's `transformer_additional_kwargs` section.
transformer = ZImageControlTransformer2DModel.from_pretrained(
    model_name,
    subfolder="transformer",
    low_cpu_mem_usage=True,
    torch_dtype=weight_dtype,
    transformer_additional_kwargs=OmegaConf.to_container(config['transformer_additional_kwargs']),
).to(weight_dtype).to(device)

# Optionally overlay weights from a separate checkpoint on top of the base
# transformer. strict=False allows the checkpoint to cover only a subset of
# keys; the printed counts show how well it matched.
if transformer_path is not None:
    print(f"From checkpoint: {transformer_path}")
    if transformer_path.endswith("safetensors"):
        from safetensors.torch import load_file
        state_dict = load_file(transformer_path)
    else:
        # NOTE(review): torch.load unpickles the file. Only load trusted
        # checkpoints (torch >= 2.6 defaults to weights_only=True).
        state_dict = torch.load(transformer_path, map_location="cpu")
    # Some training checkpoints nest the weights under a "state_dict" key.
    state_dict = state_dict["state_dict"] if "state_dict" in state_dict else state_dict

    m, u = transformer.load_state_dict(state_dict, strict=False)
    print(f"missing keys: {len(m)}, unexpected keys: {len(u)}")
    # The weights have been copied into the model. Drop the host-side copy so
    # it does not stay resident for the rest of the run.
    del state_dict
|
|
|
|
| |
# Look up the scheduler class for the configured sampler name, then load its
# settings from the base checkpoint's `scheduler` subfolder.
_scheduler_classes = {
    "Flow": FlowMatchEulerDiscreteScheduler,
    "Flow_Unipc": FlowUniPCMultistepScheduler,
    "Flow_DPM++": FlowDPMSolverMultistepScheduler,
}
Chosen_Scheduler = _scheduler_classes[sampler_name]
scheduler = Chosen_Scheduler.from_pretrained(model_name, subfolder="scheduler")
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
# Seeded generator on the sampling device for reproducible noise.
generator = torch.Generator(device=device).manual_seed(seed)

# Load the control image at the sample size. `[:, :, 0]` keeps the first
# entry of the third axis (presumably the frame axis of get_image_latent's
# output; TODO confirm).
if control_image is not None:
    control_image = get_image_latent(control_image, sample_size=sample_size)[:, :, 0]
    # Printed only when an image was loaded; the unconditional print
    # previously crashed with control_image = None.
    print(control_image.shape)
|
|
|
|
height, width = sample_size

# Pixel-to-latent downscale factor (also read as a global by prepare_latents).
# Pixel sizes must be multiples of twice this factor.
vae_scale_factor = 8
vae_scale = vae_scale_factor * 2

if height % vae_scale != 0:
    raise ValueError(
        f"Height must be divisible by {vae_scale} (got {height}). "
        f"Please adjust the height to a multiple of {vae_scale}."
    )
if width % vae_scale != 0:
    raise ValueError(
        f"Width must be divisible by {vae_scale} (got {width}). "
        f"Please adjust the width to a multiple of {vae_scale}."
    )

# Sampling-time guidance settings. Use the configured guidance_scale rather
# than silently overwriting it with 0.0, which made that setting a no-op.
_guidance_scale = guidance_scale
_joint_attention_kwargs = joint_attention_kwargs = None
_interrupt = False
_cfg_normalization = cfg_normalization = False
_cfg_truncation = cfg_truncation = 1.0
|
|
|
|
| |
# Infer the batch size from the prompt, or from precomputed embeddings. The
# previous unconditional `batch_size = 1` discarded this value, and its `else`
# branch called len(None).
prompt_embeds = None
if isinstance(prompt, str):
    batch_size = 1
elif isinstance(prompt, list):
    batch_size = len(prompt)
elif prompt_embeds is not None:
    batch_size = len(prompt_embeds)
else:
    raise ValueError("Either `prompt` or `prompt_embeds` must be provided.")

# Use the text encoder's dtype for the VAE / control tensors below.
weight_dtype = text_encoder.dtype
# Latent channel count and latent normalization constants.
num_channels_latents = 16
vae_config_shift_factor = 0.1159
vae_config_scaling_factor = 0.3611
inpaint_latent = None
|
|
from diffusers.image_processor import VaeImageProcessor
# Resizes/normalizes images on a grid of 2 * vae_scale_factor pixels.
image_processor = VaeImageProcessor(vae_scale_factor=vae_scale_factor * 2)

vae = AutoencoderKL.from_pretrained(
    model_name,
    subfolder="vae"
).to(weight_dtype).to(device)

if control_image is not None:
    control_image = image_processor.preprocess(control_image, height=height, width=width)
    control_image = control_image.to(dtype=weight_dtype, device=device)
    with torch.no_grad():
        # Deterministic encoding: take the mode of the latent distribution.
        control_latents = vae.encode(control_image)[0].mode()
    # Normalize with the VAE's own config so encoding matches the decode
    # step, which uses vae.config.scaling_factor / shift_factor.
    control_latents = (control_latents - vae.config.shift_factor) * vae.config.scaling_factor
else:
    # No control image: use an all-zero control latent on the same latent
    # grid as prepare_latents. (Previously zeros_like(None) crashed.)
    # NOTE(review): assumes the control branch expects num_channels_latents
    # channels, as the VAE output does; confirm against the transformer config.
    control_latents = torch.zeros(
        batch_size,
        num_channels_latents,
        2 * (height // (vae_scale_factor * 2)),
        2 * (width // (vae_scale_factor * 2)),
        dtype=weight_dtype,
        device=device,
    )

# Insert a singleton axis at dim 2, matching the latent inputs fed to the transformer.
control_context = control_latents.unsqueeze(2)
|
|
|
|
# Text conditioning. Classifier-free guidance is off for this run, so no
# negative embeddings are computed (encode_prompt returns them empty).
do_classifier_free_guidance = False
negative_prompt_embeds = None
max_sequence_length = 512

prompt_embeds, negative_prompt_embeds = encode_prompt(
    prompt,
    device=device,
    do_classifier_free_guidance=do_classifier_free_guidance,
    negative_prompt=negative_prompt,
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    max_sequence_length=max_sequence_length,
)
|
|
num_images_per_prompt = 1
latents = None

# Reference values observed in a debugger session (device cuda:0):
#   latents[0, 0, 100:105, 100]
#   -> tensor([-0.9203, 1.3958, 0.8130, -0.5280, -1.9788])

# Draw the initial noise in float32; the denoising loop asserts latents stay float32.
latents = prepare_latents(
    batch_size=batch_size * num_images_per_prompt,
    num_channels_latents=num_channels_latents,
    height=height,
    width=width,
    dtype=torch.float32,
    device=device,
    generator=generator,
    latents=latents,
)
print(latents.shape)
| |
|
|
| |
# Give every generated image its own copy of its prompt's embeddings.
if num_images_per_prompt > 1:
    copies = num_images_per_prompt
    prompt_embeds = [emb for emb in prompt_embeds for _ in range(copies)]
    if do_classifier_free_guidance and negative_prompt_embeds:
        negative_prompt_embeds = [emb for emb in negative_prompt_embeds for _ in range(copies)]

actual_batch_size = batch_size * num_images_per_prompt
# Token count if the latent is split into 2x2 patches (presumably the
# transformer's patch size); used below to pick the timestep shift.
latent_rows, latent_cols = latents.shape[2], latents.shape[3]
image_seq_len = (latent_rows // 2) * (latent_cols // 2)
|
|
|
|
| |
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly map a sequence length to the timestep shift ``mu``.

    The line passes through ``(base_seq_len, base_shift)`` and
    ``(max_seq_len, max_shift)``. Lengths outside that range extrapolate
    along the same line.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
|
|
| |
# Shift the sigma schedule based on the latent sequence length. Bounds come
# from the scheduler config when present, otherwise the defaults below.
_shift_config = scheduler.config
mu = calculate_shift(
    image_seq_len,
    base_seq_len=_shift_config.get("base_image_seq_len", 256),
    max_seq_len=_shift_config.get("max_image_seq_len", 4096),
    base_shift=_shift_config.get("base_shift", 0.5),
    max_shift=_shift_config.get("max_shift", 1.15),
)
scheduler.sigma_min = 0.0
scheduler_kwargs = dict(mu=mu)
|
|
|
|
import inspect


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Run ``scheduler.set_timesteps`` and return the resulting schedule.

    At most one of ``timesteps`` / ``sigmas`` may be given to override the
    scheduler's own spacing. Otherwise ``num_inference_steps`` drives it.
    Extra ``kwargs`` are forwarded to ``scheduler.set_timesteps``.

    Args:
        scheduler (`SchedulerMixin`):
            Scheduler whose ``set_timesteps`` is called.
        num_inference_steps (`int`, *optional*):
            Number of denoising steps when no custom schedule is given.
        device (`str` or `torch.device`, *optional*):
            Device the timesteps are placed on. ``None`` leaves them where the
            scheduler puts them.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule. Mutually exclusive with ``sigmas``.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule. Mutually exclusive with ``timesteps``.

    Returns:
        `Tuple[torch.Tensor, int]`: the scheduler's timesteps and the number
        of inference steps. With a custom schedule, the step count is its length.

    Raises:
        ValueError: If both custom schedules are given, or the scheduler's
            ``set_timesteps`` does not accept the one that was given.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Exactly one custom schedule was given. Check the scheduler accepts it.
    if timesteps is not None:
        param_name, custom_schedule, label = "timesteps", timesteps, "timestep"
    else:
        param_name, custom_schedule, label = "sigmas", sigmas, "sigmas"

    accepted_params = set(inspect.signature(scheduler.set_timesteps).parameters.keys())
    if param_name not in accepted_params:
        raise ValueError(
            f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
            f" {label} schedules. Please check whether you are using the correct scheduler."
        )

    scheduler.set_timesteps(**{param_name: custom_schedule}, device=device, **kwargs)
    schedule = scheduler.timesteps
    return schedule, len(schedule)
|
|
|
|
# Build the (shifted) timestep schedule; no custom sigmas are used.
sigmas = None
timesteps, num_inference_steps = retrieve_timesteps(
    scheduler,
    num_inference_steps=num_inference_steps,
    device=device,
    sigmas=sigmas,
    **scheduler_kwargs,
)
num_warmup_steps = max(len(timesteps) - num_inference_steps * scheduler.order, 0)
# Reference schedule observed in a debugger session (device cuda:0):
#   tensor([1000.0000, 954.5454, 900.0000, 833.3333, 750.0000, 642.8571,
#           500.0000, 300.0000, 0.0000])
_num_timesteps = len(timesteps)

# Optional per-step callback (disabled) and the names of the tensors it receives.
callback_on_step_end = None
callback_on_step_end_tensor_inputs = ['latents']
|
|
| |
| |
# Denoising loop. Autograd is off so per-step transformer calls do not chain a
# growing graph through `latents`. The stray pdb breakpoint that halted every
# step has been removed.
with torch.no_grad():
    for i, t in enumerate(timesteps):
        # Scheduler timesteps run from 1000 down to 0. The transformer gets
        # (1000 - t) / 1000, i.e. 0.0 at the first step and 1.0 at the last.
        timestep = t.expand(latents.shape[0])
        timestep = (1000 - timestep) / 1000
        t_norm = timestep[0].item()

        # CFG truncation: once normalized progress exceeds the cutoff, guidance is disabled.
        current_guidance_scale = guidance_scale
        if (
            do_classifier_free_guidance
            and _cfg_truncation is not None
            and float(_cfg_truncation) <= 1
            and t_norm > _cfg_truncation
        ):
            current_guidance_scale = 0.0

        apply_cfg = do_classifier_free_guidance and current_guidance_scale > 0

        latents_typed = latents.to(transformer.dtype)
        if apply_cfg:
            # Run positive and negative branches in one batched forward pass.
            latent_model_input = latents_typed.repeat(2, 1, 1, 1)
            prompt_embeds_model_input = prompt_embeds + negative_prompt_embeds
            timestep_model_input = timestep.repeat(2)
        else:
            latent_model_input = latents_typed
            prompt_embeds_model_input = prompt_embeds
            timestep_model_input = timestep

        # The transformer takes a list of per-sample tensors with an extra
        # singleton axis at dim 2.
        latent_model_input = latent_model_input.unsqueeze(2)
        latent_model_input_list = list(latent_model_input.unbind(dim=0))

        model_out_list = transformer(
            latent_model_input_list,
            timestep_model_input,
            prompt_embeds_model_input,
            control_context=control_context,
            control_context_scale=control_context_scale,
        )[0]

        if apply_cfg:
            # First half of the outputs is the positive branch, second half the negative.
            pos_out = model_out_list[:actual_batch_size]
            neg_out = model_out_list[actual_batch_size:]

            noise_pred = []
            for j in range(actual_batch_size):
                pos = pos_out[j].float()
                neg = neg_out[j].float()
                pred = pos + current_guidance_scale * (pos - neg)

                # Optional renormalization: cap the guided prediction's norm
                # at _cfg_normalization times the positive prediction's norm.
                if _cfg_normalization and float(_cfg_normalization) > 0.0:
                    ori_pos_norm = torch.linalg.vector_norm(pos)
                    new_pos_norm = torch.linalg.vector_norm(pred)
                    max_new_norm = ori_pos_norm * float(_cfg_normalization)
                    if new_pos_norm > max_new_norm:
                        pred = pred * (max_new_norm / new_pos_norm)

                noise_pred.append(pred)
            noise_pred = torch.stack(noise_pred, dim=0)
        else:
            # `out` (not `t`) so the comprehension doesn't visually shadow the loop timestep.
            noise_pred = torch.stack([out.float() for out in model_out_list], dim=0)

        # Drop the singleton axis and negate before stepping. The scheduler is
        # fed the negated model output (presumably the opposite sign of what
        # the model predicts; TODO confirm).
        noise_pred = noise_pred.squeeze(2)
        noise_pred = -noise_pred

        latents = scheduler.step(noise_pred.to(torch.float32), t, latents, return_dict=False)[0]
        assert latents.dtype == torch.float32

        if callback_on_step_end is not None:
            callback_kwargs = {}
            for k in callback_on_step_end_tensor_inputs:
                callback_kwargs[k] = locals()[k]
            # There is no pipeline object in this script (the old code passed
            # an undefined `self`), so the first callback argument is None.
            callback_outputs = callback_on_step_end(None, i, t, callback_kwargs)

            latents = callback_outputs.pop("latents", latents)
            prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
            negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
|
|
| |
| |
| |
|
|
output_type = "pil"
if output_type == "latent":
    image = latents
else:
    # Undo the latent normalization before decoding.
    latents = latents.to(vae.dtype)
    latents = (latents / vae.config.scaling_factor) + vae.config.shift_factor

    with torch.no_grad():
        image = vae.decode(latents, return_dict=False)[0]
    # Reference values observed in a debugger session (device cuda:0):
    #   image[0, 0, 100, 100:105]
    #   -> tensor([0.3906, 0.3848, 0.3809, 0.3809, 0.3848], dtype=torch.bfloat16)
    # (The pdb breakpoint that stopped here has been removed.)

    image = image_processor.postprocess(image, output_type=output_type)

    # Write the results to save_path. Previously they were computed and then
    # discarded. Files are numbered after whatever the directory already holds.
    os.makedirs(save_path, exist_ok=True)
    start_index = len(os.listdir(save_path)) + 1
    for offset, pil_image in enumerate(image):
        out_path = os.path.join(save_path, f"{start_index + offset:08d}.png")
        pil_image.save(out_path)
        print(f"Saved {out_path}")
|
|
| |
| |
|
|
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |