salmankhanpm's picture
Add files using upload-large-folder tool
4f4376a verified
# Copyright 2025 Alibaba Z-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any
import torch
from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance
from ...models import ZImageTransformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
from ...utils import logging
from ..modular_pipeline import (
BlockState,
LoopSequentialPipelineBlocks,
ModularPipelineBlocks,
PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam
from .modular_pipeline import ZImageModularPipeline
# Module-level logger, obtained from the package's own `logging` utility and keyed by this module's name.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class ZImageLoopBeforeDenoiser(ModularPipelineBlocks):
    """Loop sub-block that derives the per-step denoiser inputs from the current latents.

    Each call writes two attributes on ``block_state``:

    - ``latent_model_input``: a Python list holding one tensor per batch item.
    - ``timestep``: the loop timestep broadcast over the batch and rescaled as
      ``(1000 - t) / 1000``.
    """

    model_name = "z-image"

    @property
    def description(self) -> str:
        """Human-readable summary of what this block does."""
        return "".join(
            (
                "step within the denoising loop that prepares the latent input for the denoiser. ",
                "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` ",
                "object (e.g. `ZImageDenoiseLoopWrapper`)",
            )
        )

    @property
    def inputs(self) -> list[InputParam]:
        """Declare the state values this block reads: ``latents`` and ``dtype`` (both required)."""
        latents_param = InputParam(
            "latents",
            required=True,
            type_hint=torch.Tensor,
            description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
        )
        dtype_param = InputParam(
            "dtype",
            required=True,
            type_hint=torch.dtype,
            description="The dtype of the model inputs. Can be generated in input step.",
        )
        return [latents_param, dtype_param]

    @torch.no_grad()
    def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        """Prepare ``latent_model_input`` and ``timestep`` for the current loop iteration.

        Args:
            components: Pipeline components; passed through untouched by this block.
            block_state: Mutable loop state; ``latents`` and ``dtype`` are read from it.
            i: Index of the current iteration (unused here).
            t: Timestep for the current iteration.

        Returns:
            The ``(components, block_state)`` pair, with ``block_state`` updated in place.
        """
        model_dtype = block_state.dtype

        # Insert a singleton axis at position 2 and cast to the model dtype.
        # Per the original author's note the result is [batch_size, num_channels, 1, height, width].
        expanded_latents = block_state.latents.unsqueeze(2).to(model_dtype)
        batch_size = expanded_latents.shape[0]

        # The denoiser is fed a list with one tensor per sample rather than a batched tensor.
        block_state.latent_model_input = list(expanded_latents.unbind(dim=0))

        # Broadcast the timestep over the batch, then flip and rescale it by 1000.
        per_sample_t = t.expand(batch_size).to(model_dtype)
        block_state.timestep = (1000 - per_sample_t) / 1000

        return components, block_state
class ZImageLoopDenoiser(ModularPipelineBlocks):
    """Loop sub-block that runs the transformer once per guidance batch and combines the predictions.

    It reads ``latent_model_input`` and ``timestep`` from ``block_state`` and writes
    the guided prediction to ``block_state.noise_pred``.
    """

    model_name = "z-image"

    def __init__(
        self,
        guider_input_fields: "dict[str, Any] | None" = None,
    ):
        """Initialize a denoiser block that calls the denoiser model. This block is used in Z-Image.

        Args:
            guider_input_fields: A dictionary that maps each argument expected by the denoiser model
                (for example, "encoder_hidden_states") to data stored on `block_state`. The value can be either:

                - A tuple of strings. For instance, {"encoder_hidden_states": ("prompt_embeds",
                  "negative_prompt_embeds")} tells the guider to read `block_state.prompt_embeds` and
                  `block_state.negative_prompt_embeds` and pass them as the conditional and unconditional
                  batches of 'encoder_hidden_states'.
                - A string. For example, {"encoder_hidden_image": "image_embeds"} makes the guider forward
                  `block_state.image_embeds` for both conditional and unconditional batches.

                When omitted (``None``), defaults to
                ``{"cap_feats": ("prompt_embeds", "negative_prompt_embeds")}``.

        Raises:
            ValueError: If `guider_input_fields` is given but is not a dictionary.
        """
        if guider_input_fields is None:
            # Build the default per instance: a dict literal in the signature would be a single
            # object shared (and mutable) across every default-constructed block.
            guider_input_fields = {"cap_feats": ("prompt_embeds", "negative_prompt_embeds")}
        if not isinstance(guider_input_fields, dict):
            raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
        # Shallow copy so later mutation of the caller's dict cannot silently change this block.
        self._guider_input_fields = dict(guider_input_fields)
        super().__init__()

    @property
    def expected_components(self) -> list[ComponentSpec]:
        """Components this block needs: a ``guider`` (created from config by default) and the ``transformer``."""
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0, "enabled": False}),
                default_creation_method="from_config",
            ),
            ComponentSpec("transformer", ZImageTransformer2DModel),
        ]

    @property
    def description(self) -> str:
        """Human-readable summary of what this block does."""
        return (
            "Step within the denoising loop that denoise the latents with guidance. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `ZImageDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> list[InputParam]:
        """Declare the state values this block reads.

        Besides the fixed entries, one input is added for every state field named in
        ``guider_input_fields``: conditional names (and plain-string entries) are required,
        unconditional names are optional. All conditional names are listed before any
        unconditional ones.
        """
        inputs = [
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                kwargs_type="denoiser_input_fields",
                description="The conditional model inputs for the denoiser: e.g. prompt_embeds, negative_prompt_embeds, etc.",
            ),
        ]

        cond_names = []
        uncond_names = []
        for value in self._guider_input_fields.values():
            if isinstance(value, tuple):
                # (conditional_field, unconditional_field)
                cond_names.append(value[0])
                uncond_names.append(value[1])
            else:
                # A single field name is shared by both branches.
                cond_names.append(value)

        inputs.extend(InputParam(name=name, required=True) for name in cond_names)
        inputs.extend(InputParam(name=name) for name in uncond_names)
        return inputs

    @staticmethod
    def _to_dtype(value, dtype):
        """Cast a tensor, or every tensor inside (possibly nested) lists, to `dtype`; return anything else as-is."""
        if isinstance(value, torch.Tensor):
            return value.to(dtype)
        if isinstance(value, list):
            return [ZImageLoopDenoiser._to_dtype(item, dtype) for item in value]
        return value

    @torch.no_grad()
    def __call__(
        self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> tuple[ZImageModularPipeline, BlockState]:
        """Run the transformer for every guidance batch and store the guided prediction.

        Args:
            components: Pipeline components; ``guider`` and ``transformer`` are used.
            block_state: Mutable loop state. Reads ``num_inference_steps``, ``dtype``,
                ``latent_model_input``, ``timestep`` and the fields named in
                ``guider_input_fields``; writes ``noise_pred``.
            i: Index of the current iteration.
            t: Timestep for the current iteration.

        Returns:
            The ``(components, block_state)`` pair, with ``block_state`` updated in place.
        """
        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

        # The guider splits model inputs into separate batches for conditional/unconditional predictions.
        # For CFG with guider_inputs = {"encoder_hidden_states": (prompt_embeds, negative_prompt_embeds)}:
        # you will get a guider_state with two batches:
        #   guider_state = [
        #       {"encoder_hidden_states": prompt_embeds, "__guidance_identifier__": "pred_cond"},  # conditional batch
        #       {"encoder_hidden_states": negative_prompt_embeds, "__guidance_identifier__": "pred_uncond"},  # unconditional batch
        #   ]
        # Other guidance methods may return 1 batch (no guidance) or 3+ batches (e.g., PAG, APG).
        guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)

        # Run the denoiser for each guidance batch.
        for guider_state_batch in guider_state:
            components.guider.prepare_models(components.transformer)

            # Forward only the configured model arguments (this drops bookkeeping entries such as the
            # guidance identifier), cast to the model dtype.
            cond_kwargs = {
                k: self._to_dtype(v, block_state.dtype)
                for k, v in guider_state_batch.as_dict().items()
                if k in self._guider_input_fields
            }

            # Predict the noise residual. The first element of the transformer output is stacked
            # along a new batch axis (so it is a sequence of per-sample tensors), and axis 2 is
            # then squeezed away.
            model_out_list = components.transformer(
                x=block_state.latent_model_input,
                t=block_state.timestep,
                return_dict=False,
                **cond_kwargs,
            )[0]
            noise_pred = torch.stack(model_out_list, dim=0).squeeze(2)

            # Store the prediction on the batch so guidance can be applied across all batches.
            # NOTE(review): the sign is flipped here - presumably the transformer's output
            # convention is the opposite of what the scheduler expects; confirm.
            guider_state_batch.noise_pred = -noise_pred
            components.guider.cleanup_models(components.transformer)

        # Perform guidance.
        block_state.noise_pred = components.guider(guider_state)[0]

        return components, block_state
class ZImageLoopAfterDenoiser(ModularPipelineBlocks):
    """Loop sub-block that advances the latents by one scheduler step using ``block_state.noise_pred``."""

    model_name = "z-image"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        """Components this block needs: the ``scheduler``."""
        scheduler_spec = ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)
        return [scheduler_spec]

    @property
    def description(self) -> str:
        """Human-readable summary of what this block does."""
        return "".join(
            (
                "step within the denoising loop that update the latents. ",
                "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` ",
                "object (e.g. `ZImageDenoiseLoopWrapper`)",
            )
        )

    @torch.no_grad()
    def __call__(self, components: ZImageModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        """Apply one scheduler step and write the result back to ``block_state.latents``.

        Args:
            components: Pipeline components; ``scheduler`` is used.
            block_state: Mutable loop state; reads ``noise_pred`` and ``latents``, writes ``latents``.
            i: Index of the current iteration (unused here).
            t: Timestep for the current iteration, forwarded to the scheduler.

        Returns:
            The ``(components, block_state)`` pair, with ``block_state`` updated in place.
        """
        # Remember the incoming dtype: the step itself is run on float32 copies.
        original_dtype = block_state.latents.dtype

        step_output = components.scheduler.step(
            block_state.noise_pred.float(),
            t,
            block_state.latents.float(),
            return_dict=False,
        )
        updated_latents = step_output[0]

        # Restore the dtype the latents had before the step, if the scheduler changed it.
        if updated_latents.dtype != original_dtype:
            updated_latents = updated_latents.to(original_dtype)
        block_state.latents = updated_latents

        return components, block_state
class ZImageDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    """Loop driver that iterates over ``timesteps`` and runs the configured sub-blocks on each one."""

    model_name = "z-image"

    @property
    def description(self) -> str:
        """Human-readable summary of what this block does."""
        return "".join(
            (
                "Pipeline block that iteratively denoise the latents over `timesteps`. ",
                "The specific steps with each iteration can be customized with `sub_blocks` attributes",
            )
        )

    @property
    def loop_expected_components(self) -> list[ComponentSpec]:
        """Components the loop itself needs: the ``scheduler`` (its ``order`` is read in ``__call__``)."""
        scheduler_spec = ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)
        return [scheduler_spec]

    @property
    def loop_inputs(self) -> list[InputParam]:
        """Declare the state values the loop itself reads: ``timesteps`` and ``num_inference_steps``."""
        timesteps_param = InputParam(
            "timesteps",
            required=True,
            type_hint=torch.Tensor,
            description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
        )
        num_steps_param = InputParam(
            "num_inference_steps",
            required=True,
            type_hint=int,
            description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
        )
        return [timesteps_param, num_steps_param]

    @torch.no_grad()
    def __call__(self, components: ZImageModularPipeline, state: PipelineState) -> PipelineState:
        """Run ``loop_step`` once per timestep while driving the progress bar.

        Args:
            components: Pipeline components; ``scheduler.order`` is read for progress accounting.
            state: Pipeline state from which the block state is extracted and to which it is written back.

        Returns:
            The ``(components, state)`` pair.
            NOTE(review): the return annotation says ``PipelineState`` but a tuple is returned.
        """
        block_state = self.get_block_state(state)

        # Timesteps beyond `num_inference_steps * order` count as warmup; never negative.
        scheduled_steps = block_state.num_inference_steps * components.scheduler.order
        surplus_steps = len(block_state.timesteps) - scheduled_steps
        block_state.num_warmup_steps = max(surplus_steps, 0)

        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
            for i, t in enumerate(block_state.timesteps):
                components, block_state = self.loop_step(components, block_state, i=i, t=t)

                # Tick on the final timestep, or once warmup is over and a full group of
                # `scheduler.order` timesteps has completed.
                completed = i + 1
                is_final = completed == len(block_state.timesteps)
                if is_final or (
                    completed > block_state.num_warmup_steps and completed % components.scheduler.order == 0
                ):
                    progress_bar.update()

        self.set_block_state(state, block_state)
        return components, state
class ZImageDenoiseStep(ZImageDenoiseLoopWrapper):
    """Concrete denoising loop: prepare inputs, run the guided denoiser, then step the scheduler."""

    # Names line up one-to-one with the entries of `block_classes` below.
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

    # The denoiser entry is a configured instance (it takes constructor arguments);
    # the other two entries are given as classes.
    block_classes = [
        ZImageLoopBeforeDenoiser,
        ZImageLoopDenoiser(
            guider_input_fields={
                "cap_feats": ("prompt_embeds", "negative_prompt_embeds"),
            }
        ),
        ZImageLoopAfterDenoiser,
    ]

    @property
    def description(self) -> str:
        """Human-readable summary of the loop and the sub-blocks it runs."""
        return "".join(
            (
                "Denoise step that iteratively denoise the latents. \n",
                "Its loop logic is defined in `ZImageDenoiseLoopWrapper.__call__` method \n",
                "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n",
                " - `ZImageLoopBeforeDenoiser`\n",
                " - `ZImageLoopDenoiser`\n",
                " - `ZImageLoopAfterDenoiser`\n",
                "This block supports text-to-image and image-to-image tasks for Z-Image.",
            )
        )