File size: 19,059 Bytes
be9fa39 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 |
# Copyright 2025 Qwen-Image Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Tuple
import torch
from ...models import QwenImageMultiControlNetModel
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
def repeat_tensor_to_batch_size(
    input_name: str,
    input_tensor: torch.Tensor,
    batch_size: int,
    num_images_per_prompt: int = 1,
) -> torch.Tensor:
    """Repeat tensor elements along dim 0 to reach the final batch size.

    The final batch size is `batch_size * num_images_per_prompt`. Each element is repeated consecutively
    (`repeat_interleave`), so all copies that belong to one prompt stay adjacent.

    The input tensor must have batch size 1 or `batch_size`:
        - batch size 1: every element is repeated `batch_size * num_images_per_prompt` times
        - batch size `batch_size`: every element is repeated `num_images_per_prompt` times

    Args:
        input_name (str): Name of the input tensor (only used in error messages).
        input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or `batch_size`.
        batch_size (int): The base batch size (number of prompts).
        num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1.

    Returns:
        torch.Tensor: The repeated tensor, with `shape[0] == batch_size * num_images_per_prompt`.

    Raises:
        ValueError: If `input_tensor` is not a `torch.Tensor`, or its batch size is neither 1 nor `batch_size`.

    Examples:
        >>> tensor = torch.tensor([[1, 2, 3]])  # shape: [1, 3]
        >>> repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_images_per_prompt=2).tolist()
        [[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]

        >>> tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: [2, 3]
        >>> repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_images_per_prompt=2).tolist()
        [[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]
    """
    # make sure input is a tensor
    if not isinstance(input_tensor, torch.Tensor):
        raise ValueError(f"`{input_name}` must be a tensor")

    # make sure input tensor e.g. image_latents has batch size 1 or batch_size same as prompts.
    # The `== 1` test comes first on purpose: with batch_size == 1 both branches give the same factor.
    current_batch = input_tensor.shape[0]
    if current_batch == 1:
        repeat_by = batch_size * num_images_per_prompt
    elif current_batch == batch_size:
        repeat_by = num_images_per_prompt
    else:
        raise ValueError(f"`{input_name}` must have batch size 1 or {batch_size}, but got {current_batch}")

    # expand the tensor to match batch_size * num_images_per_prompt
    return input_tensor.repeat_interleave(repeat_by, dim=0)
def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> Tuple[int, int]:
    """Derive image-space (height, width) from an unpacked latent tensor.

    The last two dimensions of `latents` are read as latent height and latent width, and each is multiplied by
    `vae_scale_factor`.

    Args:
        latents (torch.Tensor): Unpacked latent tensor with 4 or 5 dimensions, e.g.
            [batch, channels, height, width] or [batch, channels, frames, height, width].
        vae_scale_factor (int): Multiplier that maps a latent-space size to an image-space size.

    Returns:
        Tuple[int, int]: The image dimensions as (height, width).

    Raises:
        ValueError: If `latents` has neither 4 nor 5 dimensions (for example because it is already packed).
    """
    # Any other rank means the spatial layout cannot be read from the last two axes.
    if latents.ndim not in (4, 5):
        raise ValueError(f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}")

    # Only the trailing two axes matter; everything before them is ignored.
    *_, latent_height, latent_width = latents.shape
    return latent_height * vae_scale_factor, latent_width * vae_scale_factor
class QwenImageTextInputsStep(ModularPipelineBlocks):
    """Pipeline block that standardizes the text embeddings.

    It validates the prompt embeddings and their masks, writes `batch_size` and `dtype` (both taken from
    `prompt_embeds`) to the block state, and expands every text tensor along dim 0 so that its batch size becomes
    `batch_size * num_images_per_prompt`.
    """

    model_name = "qwenimage"

    @property
    def description(self) -> str:
        """Human-readable summary of the block and where it should be placed."""
        summary_section = (
            "Text input processing step that standardizes text embeddings for the pipeline.\n"
            "This step:\n"
            " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
            " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
        )
        # Placement guidance
        placement_section = "\n\nThis block should be placed after all encoder steps to process the text embeddings before they are used in subsequent pipeline steps."
        return summary_section + placement_section

    @property
    def inputs(self) -> List[InputParam]:
        """Inputs read from the pipeline state; only the positive embeddings and mask are required."""
        return [
            InputParam(name="num_images_per_prompt", default=1),
            InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"),
            InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"),
            InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"),
            InputParam(name="negative_prompt_embeds_mask", kwargs_type="denoiser_input_fields"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """Values this block adds to the pipeline state: `batch_size` and `dtype`."""
        return [
            OutputParam(
                "batch_size",
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
            ),
            OutputParam(
                "dtype",
                type_hint=torch.dtype,
                description="Data type of model tensor inputs (determined by `prompt_embeds`)",
            ),
        ]

    @staticmethod
    def check_inputs(
        prompt_embeds,
        prompt_embeds_mask,
        negative_prompt_embeds,
        negative_prompt_embeds_mask,
    ):
        """Validate that embeddings and masks are passed in pairs and agree on batch size.

        Raises:
            ValueError: If only one of `negative_prompt_embeds` / `negative_prompt_embeds_mask` is given, or if any
                mask or negative tensor has a different batch size than `prompt_embeds`.
        """
        # The negative embeddings and their mask must be provided together or not at all.
        if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
            raise ValueError("`negative_prompt_embeds_mask` is required when `negative_prompt_embeds` is not None")

        if negative_prompt_embeds is None and negative_prompt_embeds_mask is not None:
            raise ValueError("cannot pass `negative_prompt_embeds_mask` without `negative_prompt_embeds`")

        # Every tensor is compared against the batch size of `prompt_embeds`.
        expected_batch = prompt_embeds.shape[0]
        if prompt_embeds_mask.shape[0] != expected_batch:
            raise ValueError("`prompt_embeds_mask` must have the same batch size as `prompt_embeds`")

        if negative_prompt_embeds is not None and negative_prompt_embeds.shape[0] != expected_batch:
            raise ValueError("`negative_prompt_embeds` must have the same batch size as `prompt_embeds`")

        if negative_prompt_embeds_mask is not None and negative_prompt_embeds_mask.shape[0] != expected_batch:
            raise ValueError("`negative_prompt_embeds_mask` must have the same batch size as `prompt_embeds`")

    def __call__(
        self, components: QwenImageModularPipeline, state: PipelineState
    ) -> Tuple[QwenImageModularPipeline, PipelineState]:
        """Validate the text inputs, record `batch_size`/`dtype`, and expand the text tensors.

        Returns:
            The `(components, state)` pair, with the updated block state written back into `state`.
        """
        block_state = self.get_block_state(state)

        self.check_inputs(
            prompt_embeds=block_state.prompt_embeds,
            prompt_embeds_mask=block_state.prompt_embeds_mask,
            negative_prompt_embeds=block_state.negative_prompt_embeds,
            negative_prompt_embeds_mask=block_state.negative_prompt_embeds_mask,
        )

        block_state.batch_size = block_state.prompt_embeds.shape[0]
        block_state.dtype = block_state.prompt_embeds.dtype

        # Expand every text tensor with `repeat_interleave` on dim 0, so the copies for one prompt stay adjacent
        # (p0, p0, ..., p1, p1, ...). The previous `.repeat(1, n, 1)` + `.view(...)` produced that ordering for
        # the 3-D embeddings but *tiled* the 2-D masks (p0, p1, ..., p0, p1, ...), which paired masks with the
        # wrong prompts whenever batch_size > 1 and num_images_per_prompt > 1.
        repeats = block_state.num_images_per_prompt

        block_state.prompt_embeds = block_state.prompt_embeds.repeat_interleave(repeats, dim=0)
        block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.repeat_interleave(repeats, dim=0)

        if block_state.negative_prompt_embeds is not None:
            block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat_interleave(
                repeats, dim=0
            )
            block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.repeat_interleave(
                repeats, dim=0
            )

        self.set_block_state(state, block_state)
        return components, state
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
    """Configurable pipeline block that standardizes tensor inputs for the denoising step.

    The names of the inputs to process are chosen at construction time; see `__init__`.
    """

    model_name = "qwenimage"

    def __init__(
        self,
        image_latent_inputs: List[str] = ["image_latents"],
        additional_batch_inputs: List[str] = [],
    ):
        """Initialize a configurable step that standardizes the inputs for the denoising step.

        This step handles multiple common tasks to prepare inputs for the denoising step:
            1. For encoded image latents: uses them to update height/width if None, patchifies them, and expands
               their batch size
            2. For additional_batch_inputs: only expands batch dimensions to match the final batch size

        This is a dynamic block that allows you to configure which inputs to process.

        Args:
            image_latent_inputs (List[str], optional): Names of image latent tensors to process.
                These will be used to determine height/width, patchified, and batch-expanded. Can be a single
                string or a list/tuple of strings. Defaults to ["image_latents"]. Examples: ["image_latents"],
                ["control_image_latents"]
            additional_batch_inputs (List[str], optional):
                Names of additional conditional input tensors to expand batch size. These tensors will only have
                their batch dimensions adjusted to match the final batch size. Can be a single string or a
                list/tuple of strings. Defaults to []. Examples: ["processed_mask_image"]

        Examples:
            # Configure to process image_latents (default behavior)
            QwenImageInputsDynamicStep()

            # Configure to process multiple image latent inputs
            QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"])

            # Configure to process image latents and additional batch inputs
            QwenImageInputsDynamicStep(
                image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
            )
        """
        # A bare name is wrapped into a one-element list; a tuple is accepted as a sequence of names.
        if not isinstance(image_latent_inputs, (list, tuple)):
            image_latent_inputs = [image_latent_inputs]
        if not isinstance(additional_batch_inputs, (list, tuple)):
            additional_batch_inputs = [additional_batch_inputs]

        # Store copies: the signature defaults are mutable lists shared by every call, and a caller-owned list
        # could be changed after construction. Copying keeps each instance's configuration independent.
        self._image_latent_inputs = list(image_latent_inputs)
        self._additional_batch_inputs = list(additional_batch_inputs)
        # Both attributes are set before the base initializer runs, as in the original ordering.
        super().__init__()

    @property
    def description(self) -> str:
        """Human-readable summary of the block, the configured input names, and placement guidance."""
        # Functionality section
        summary_section = (
            "Input processing step that:\n"
            " 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
            " 2. For additional batch inputs: Expands batch dimensions to match final batch size"
        )

        # Inputs info
        inputs_info = ""
        if self._image_latent_inputs or self._additional_batch_inputs:
            inputs_info = "\n\nConfigured inputs:"
            if self._image_latent_inputs:
                inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
            if self._additional_batch_inputs:
                inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"

        # Placement guidance
        placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."

        return summary_section + inputs_info + placement_section

    @property
    def inputs(self) -> List[InputParam]:
        """Fixed inputs followed by one optional input per configured tensor name."""
        inputs = [
            InputParam(name="num_images_per_prompt", default=1),
            InputParam(name="batch_size", required=True),
            InputParam(name="height"),
            InputParam(name="width"),
        ]

        # Image latent inputs come first, then the additional batch inputs, in their configured order.
        for configured_name in (*self._image_latent_inputs, *self._additional_batch_inputs):
            inputs.append(InputParam(name=configured_name))

        return inputs

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components this block reads from the pipeline: the `pachifier` used to pack latents."""
        return [
            ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
        ]

    def __call__(
        self, components: QwenImageModularPipeline, state: PipelineState
    ) -> Tuple[QwenImageModularPipeline, PipelineState]:
        """Process every configured input that is present in the block state.

        Inputs whose value is None are skipped.

        Returns:
            The `(components, state)` pair, with the updated block state written back into `state`.
        """
        block_state = self.get_block_state(state)

        # Process image latent inputs (height/width calculation, patchify, and batch expansion)
        for image_latent_input_name in self._image_latent_inputs:
            image_latent_tensor = getattr(block_state, image_latent_input_name)
            if image_latent_tensor is None:
                continue

            # 1. Calculate height/width from latents; `or` keeps values that were already set
            #    (any falsy value, not just None, is replaced).
            height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
            block_state.height = block_state.height or height
            block_state.width = block_state.width or width

            # 2. Patchify the image latent tensor
            image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)

            # 3. Expand batch size
            image_latent_tensor = repeat_tensor_to_batch_size(
                input_name=image_latent_input_name,
                input_tensor=image_latent_tensor,
                num_images_per_prompt=block_state.num_images_per_prompt,
                batch_size=block_state.batch_size,
            )

            setattr(block_state, image_latent_input_name, image_latent_tensor)

        # Process additional batch inputs (only batch expansion)
        for input_name in self._additional_batch_inputs:
            input_tensor = getattr(block_state, input_name)
            if input_tensor is None:
                continue

            # Only expand batch size
            input_tensor = repeat_tensor_to_batch_size(
                input_name=input_name,
                input_tensor=input_tensor,
                num_images_per_prompt=block_state.num_images_per_prompt,
                batch_size=block_state.batch_size,
            )

            setattr(block_state, input_name, input_tensor)

        self.set_block_state(state, block_state)
        return components, state
class QwenImageControlNetInputsStep(ModularPipelineBlocks):
    """Pipeline block that prepares `control_image_latents` for the controlnet.

    For each control latent tensor it fills in a missing height/width, packs the latents with the pipeline's
    `pachifier`, and expands the batch dimension to `batch_size * num_images_per_prompt`.

    NOTE(review): this block reads `components.controlnet` and `components.pachifier` but declares no
    `expected_components` — presumably other blocks in the pipeline register them; confirm.
    """

    model_name = "qwenimage"

    @property
    def description(self) -> str:
        """Human-readable summary of the block and where it should be placed."""
        return "prepare the `control_image_latents` for controlnet. Insert after all the other inputs steps."

    @property
    def inputs(self) -> List[InputParam]:
        """Inputs read from the pipeline state; `control_image_latents` and `batch_size` are required."""
        return [
            InputParam(name="control_image_latents", required=True),
            InputParam(name="batch_size", required=True),
            InputParam(name="num_images_per_prompt", default=1),
            InputParam(name="height"),
            InputParam(name="width"),
        ]

    @staticmethod
    def _prepare_control_latents(components, block_state, latents, input_name):
        """Fill in height/width on `block_state`, then return `latents` packed and batch-expanded."""
        # 1. update height/width if not provided; `or` keeps values that were already set
        height, width = calculate_dimension_from_latents(latents, components.vae_scale_factor)
        block_state.height = block_state.height or height
        block_state.width = block_state.width or width

        # 2. pack
        latents = components.pachifier.pack_latents(latents)

        # 3. repeat to match the batch size
        return repeat_tensor_to_batch_size(
            input_name=input_name,
            input_tensor=latents,
            num_images_per_prompt=block_state.num_images_per_prompt,
            batch_size=block_state.batch_size,
        )

    @torch.no_grad()
    def __call__(
        self, components: QwenImageModularPipeline, state: PipelineState
    ) -> Tuple[QwenImageModularPipeline, PipelineState]:
        """Prepare the control latents, per entry for a multi-controlnet or as one tensor otherwise.

        Runs under `torch.no_grad()`.

        Returns:
            The `(components, state)` pair, with the updated block state written back into `state`.
        """
        block_state = self.get_block_state(state)

        if isinstance(components.controlnet, QwenImageMultiControlNetModel):
            # With a multi-controlnet, `control_image_latents` is iterated and each entry is prepared on its
            # own; the index is embedded in the name so error messages identify the offending entry.
            prepared_latents = []
            for i, control_image_latents_ in enumerate(block_state.control_image_latents):
                prepared_latents.append(
                    self._prepare_control_latents(
                        components, block_state, control_image_latents_, f"control_image_latents[{i}]"
                    )
                )
            block_state.control_image_latents = prepared_latents
        else:
            block_state.control_image_latents = self._prepare_control_latents(
                components, block_state, block_state.control_image_latents, "control_image_latents"
            )

        self.set_block_state(state, block_state)
        return components, state
|