|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Optional, Tuple

import torch

from ...models import QwenImageMultiControlNetModel
from ..modular_pipeline import ModularPipelineBlocks, PipelineState
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .modular_pipeline import QwenImageModularPipeline, QwenImagePachifier
|
|
|
|
|
|
|
|
def repeat_tensor_to_batch_size(
    input_name: str,
    input_tensor: torch.Tensor,
    batch_size: int,
    num_images_per_prompt: int = 1,
) -> torch.Tensor:
    """Repeat tensor elements to match the final batch size.

    This function expands a tensor's batch dimension to match the final batch size (batch_size * num_images_per_prompt)
    by repeating each element along dimension 0.

    The input tensor must have batch size 1 or batch_size. The function will:
    - If batch size is 1: repeat each element (batch_size * num_images_per_prompt) times
    - If batch size equals batch_size: repeat each element num_images_per_prompt times

    Args:
        input_name (str): Name of the input tensor (used for error messages)
        input_tensor (torch.Tensor): The tensor to repeat. Must have batch size 1 or batch_size.
        batch_size (int): The base batch size (number of prompts)
        num_images_per_prompt (int, optional): Number of images to generate per prompt. Defaults to 1.

    Returns:
        torch.Tensor: The repeated tensor with final batch size (batch_size * num_images_per_prompt)

    Raises:
        ValueError: If input_tensor is not a torch.Tensor or has invalid batch size

    Examples:
        tensor = torch.tensor([[1, 2, 3]])  # shape: [1, 3]
        repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_images_per_prompt=2)
        # tensor([[1, 2, 3], [1, 2, 3], [1, 2, 3], [1, 2, 3]]) - shape: [4, 3]

        tensor = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape: [2, 3]
        repeat_tensor_to_batch_size("image", tensor, batch_size=2, num_images_per_prompt=2)
        # tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [4, 5, 6]]) - shape: [4, 3]
    """
    if not isinstance(input_tensor, torch.Tensor):
        raise ValueError(f"`{input_name}` must be a tensor")

    if input_tensor.shape[0] == 1:
        # A single element is shared by every prompt, so it is expanded to the full final batch.
        repeat_by = batch_size * num_images_per_prompt
    elif input_tensor.shape[0] == batch_size:
        # One element per prompt: only expand for the images generated per prompt.
        repeat_by = num_images_per_prompt
    else:
        # Fix: error message previously read "must have have batch size" (doubled word).
        raise ValueError(
            f"`{input_name}` must have batch size 1 or {batch_size}, but got {input_tensor.shape[0]}"
        )

    # repeat_interleave keeps copies of the same element adjacent: [a, b] -> [a, a, b, b].
    input_tensor = input_tensor.repeat_interleave(repeat_by, dim=0)

    return input_tensor
|
|
|
|
|
|
|
|
def calculate_dimension_from_latents(latents: torch.Tensor, vae_scale_factor: int) -> Tuple[int, int]:
    """Calculate image dimensions from latent tensor dimensions.

    This function converts latent space dimensions to image space dimensions by multiplying the latent height and width
    by the VAE scale factor.

    Args:
        latents (torch.Tensor): The latent tensor. Must have 4 or 5 dimensions.
            Expected shapes: [batch, channels, height, width] or [batch, channels, frames, height, width]
        vae_scale_factor (int): The scale factor used by the VAE to compress images.
            Typically 8 for most VAEs (image is 8x larger than latents in each dimension)

    Returns:
        Tuple[int, int]: The calculated image dimensions as (height, width)

    Raises:
        ValueError: If latents tensor doesn't have 4 or 5 dimensions
    """
    # Idiomatic membership test instead of the chained `!= ... and != ...` comparison.
    if latents.ndim not in (4, 5):
        raise ValueError(f"unpacked latents must have 4 or 5 dimensions, but got {latents.ndim}")

    # Height/width are always the last two dimensions for both 4-D and 5-D layouts.
    latent_height, latent_width = latents.shape[-2:]

    height = latent_height * vae_scale_factor
    width = latent_width * vae_scale_factor

    return height, width
|
|
|
|
|
|
|
|
class QwenImageTextInputsStep(ModularPipelineBlocks):
    # Pipeline block that validates the text embeddings produced by the encoder steps and
    # expands them to the final batch size (batch_size * num_images_per_prompt).
    model_name = "qwenimage"

    @property
    def description(self) -> str:
        """Human-readable summary of what this block does and where to place it."""
        summary_section = (
            "Text input processing step that standardizes text embeddings for the pipeline.\n"
            "This step:\n"
            " 1. Determines `batch_size` and `dtype` based on `prompt_embeds`\n"
            " 2. Ensures all text embeddings have consistent batch sizes (batch_size * num_images_per_prompt)"
        )

        placement_section = "\n\nThis block should be placed after all encoder steps to process the text embeddings before they are used in subsequent pipeline steps."

        return summary_section + placement_section

    @property
    def inputs(self) -> List[InputParam]:
        # Negative embeds/masks are optional; they are only validated/expanded when provided.
        return [
            InputParam(name="num_images_per_prompt", default=1),
            InputParam(name="prompt_embeds", required=True, kwargs_type="denoiser_input_fields"),
            InputParam(name="prompt_embeds_mask", required=True, kwargs_type="denoiser_input_fields"),
            InputParam(name="negative_prompt_embeds", kwargs_type="denoiser_input_fields"),
            InputParam(name="negative_prompt_embeds_mask", kwargs_type="denoiser_input_fields"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "batch_size",
                type_hint=int,
                description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt",
            ),
            OutputParam(
                "dtype",
                type_hint=torch.dtype,
                description="Data type of model tensor inputs (determined by `prompt_embeds`)",
            ),
        ]

    @staticmethod
    def check_inputs(
        prompt_embeds,
        prompt_embeds_mask,
        negative_prompt_embeds,
        negative_prompt_embeds_mask,
    ):
        """Validate that embeds/masks are provided in consistent pairs with matching batch sizes.

        Raises:
            ValueError: If a negative embed is given without its mask (or vice versa), or if
                any mask/negative tensor disagrees with `prompt_embeds` on batch size.
        """
        if negative_prompt_embeds is not None and negative_prompt_embeds_mask is None:
            raise ValueError("`negative_prompt_embeds_mask` is required when `negative_prompt_embeds` is not None")

        if negative_prompt_embeds is None and negative_prompt_embeds_mask is not None:
            raise ValueError("cannot pass `negative_prompt_embeds_mask` without `negative_prompt_embeds`")

        # The elif chain is equivalent to independent checks here because each branch raises.
        if prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]:
            raise ValueError("`prompt_embeds_mask` must have the same batch size as `prompt_embeds`")

        elif negative_prompt_embeds is not None and negative_prompt_embeds.shape[0] != prompt_embeds.shape[0]:
            raise ValueError("`negative_prompt_embeds` must have the same batch size as `prompt_embeds`")

        elif (
            negative_prompt_embeds_mask is not None and negative_prompt_embeds_mask.shape[0] != prompt_embeds.shape[0]
        ):
            raise ValueError("`negative_prompt_embeds_mask` must have the same batch size as `prompt_embeds`")

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        self.check_inputs(
            prompt_embeds=block_state.prompt_embeds,
            prompt_embeds_mask=block_state.prompt_embeds_mask,
            negative_prompt_embeds=block_state.negative_prompt_embeds,
            negative_prompt_embeds_mask=block_state.negative_prompt_embeds_mask,
        )

        # batch_size/dtype are derived once from prompt_embeds and published for later blocks.
        block_state.batch_size = block_state.prompt_embeds.shape[0]
        block_state.dtype = block_state.prompt_embeds.dtype

        # prompt_embeds is indexed as [batch, seq_len, hidden]; see the 3-way unpack below.
        _, seq_len, _ = block_state.prompt_embeds.shape

        # Expand along the sequence dim then reshape: each prompt's copies end up adjacent
        # in the batch dim (repeat_interleave-like ordering).
        block_state.prompt_embeds = block_state.prompt_embeds.repeat(1, block_state.num_images_per_prompt, 1)
        block_state.prompt_embeds = block_state.prompt_embeds.view(
            block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
        )

        # NOTE(review): the mask appears to be rank-2 [batch, seq_len] (the final view is 2-D);
        # `.repeat(1, n, 1)` would then broadcast a leading dim, giving tiled rather than
        # interleaved ordering for batch_size > 1 — confirm mask rank against the encoder step.
        block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.repeat(1, block_state.num_images_per_prompt, 1)
        block_state.prompt_embeds_mask = block_state.prompt_embeds_mask.view(
            block_state.batch_size * block_state.num_images_per_prompt, seq_len
        )

        if block_state.negative_prompt_embeds is not None:
            # Negative embeds may have a different sequence length than the positive ones.
            _, seq_len, _ = block_state.negative_prompt_embeds.shape
            block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.repeat(
                1, block_state.num_images_per_prompt, 1
            )
            block_state.negative_prompt_embeds = block_state.negative_prompt_embeds.view(
                block_state.batch_size * block_state.num_images_per_prompt, seq_len, -1
            )

            # check_inputs guarantees the mask is present whenever negative embeds are.
            block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.repeat(
                1, block_state.num_images_per_prompt, 1
            )
            block_state.negative_prompt_embeds_mask = block_state.negative_prompt_embeds_mask.view(
                block_state.batch_size * block_state.num_images_per_prompt, seq_len
            )

        self.set_block_state(state, block_state)

        return components, state
|
|
|
|
|
|
|
|
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
    # Configurable block that patchifies/batch-expands latent inputs and batch-expands
    # any additional conditioning tensors before denoising.
    model_name = "qwenimage"

    def __init__(
        self,
        image_latent_inputs: Optional[List[str]] = None,
        additional_batch_inputs: Optional[List[str]] = None,
    ):
        """Initialize a configurable step that standardizes the inputs for the denoising step.

        This step handles multiple common tasks to prepare inputs for the denoising step:
        1. For encoded image latents: uses them to update height/width if None, patchifies, and expands batch size
        2. For additional_batch_inputs: only expands batch dimensions to match the final batch size

        This is a dynamic block that allows you to configure which inputs to process.

        Args:
            image_latent_inputs (List[str], optional): Names of image latent tensors to process.
                These will be used to determine height/width, patchified, and batch-expanded. Can be a single string or
                list of strings. Defaults to ["image_latents"]. Examples: ["image_latents"], ["control_image_latents"]
            additional_batch_inputs (List[str], optional):
                Names of additional conditional input tensors to expand batch size. These tensors will only have their
                batch dimensions adjusted to match the final batch size. Can be a single string or list of strings.
                Defaults to []. Examples: ["processed_mask_image"]

        Examples:
            # Configure to process image_latents (default behavior)
            QwenImageInputsDynamicStep()

            # Configure to process multiple image latent inputs
            QwenImageInputsDynamicStep(image_latent_inputs=["image_latents", "control_image_latents"])

            # Configure to process image latents and additional batch inputs
            QwenImageInputsDynamicStep(
                image_latent_inputs=["image_latents"], additional_batch_inputs=["processed_mask_image"]
            )
        """
        # Fix: avoid mutable default arguments; None stands in for the documented defaults.
        if image_latent_inputs is None:
            image_latent_inputs = ["image_latents"]
        if additional_batch_inputs is None:
            additional_batch_inputs = []
        # Accept a bare string as a convenience for a single input name.
        if not isinstance(image_latent_inputs, list):
            image_latent_inputs = [image_latent_inputs]
        if not isinstance(additional_batch_inputs, list):
            additional_batch_inputs = [additional_batch_inputs]

        self._image_latent_inputs = image_latent_inputs
        self._additional_batch_inputs = additional_batch_inputs
        super().__init__()

    @property
    def description(self) -> str:
        """Summary of the block, including which inputs this instance is configured to process."""
        summary_section = (
            "Input processing step that:\n"
            " 1. For image latent inputs: Updates height/width if None, patchifies latents, and expands batch size\n"
            " 2. For additional batch inputs: Expands batch dimensions to match final batch size"
        )

        inputs_info = ""
        if self._image_latent_inputs or self._additional_batch_inputs:
            inputs_info = "\n\nConfigured inputs:"
            if self._image_latent_inputs:
                inputs_info += f"\n - Image latent inputs: {self._image_latent_inputs}"
            if self._additional_batch_inputs:
                inputs_info += f"\n - Additional batch inputs: {self._additional_batch_inputs}"

        placement_section = "\n\nThis block should be placed after the encoder steps and the text input step."

        return summary_section + inputs_info + placement_section

    @property
    def inputs(self) -> List[InputParam]:
        inputs = [
            InputParam(name="num_images_per_prompt", default=1),
            InputParam(name="batch_size", required=True),
            InputParam(name="height"),
            InputParam(name="width"),
        ]

        # The configured tensor names become optional inputs of this block.
        for image_latent_input_name in self._image_latent_inputs:
            inputs.append(InputParam(name=image_latent_input_name))

        for input_name in self._additional_batch_inputs:
            inputs.append(InputParam(name=input_name))

        return inputs

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("pachifier", QwenImagePachifier, default_creation_method="from_config"),
        ]

    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        # 1. Image latent inputs: derive image dims, patchify, then batch-expand.
        for image_latent_input_name in self._image_latent_inputs:
            image_latent_tensor = getattr(block_state, image_latent_input_name)
            if image_latent_tensor is None:
                continue

            # Derive height/width from the (unpacked) latents; user-provided values win.
            height, width = calculate_dimension_from_latents(image_latent_tensor, components.vae_scale_factor)
            block_state.height = block_state.height or height
            block_state.width = block_state.width or width

            # Patchify before batch expansion so each packed sample is repeated as a unit.
            image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)

            image_latent_tensor = repeat_tensor_to_batch_size(
                input_name=image_latent_input_name,
                input_tensor=image_latent_tensor,
                num_images_per_prompt=block_state.num_images_per_prompt,
                batch_size=block_state.batch_size,
            )

            setattr(block_state, image_latent_input_name, image_latent_tensor)

        # 2. Additional batch inputs: only batch expansion, no packing or dim inference.
        for input_name in self._additional_batch_inputs:
            input_tensor = getattr(block_state, input_name)
            if input_tensor is None:
                continue

            input_tensor = repeat_tensor_to_batch_size(
                input_name=input_name,
                input_tensor=input_tensor,
                num_images_per_prompt=block_state.num_images_per_prompt,
                batch_size=block_state.batch_size,
            )

            setattr(block_state, input_name, input_tensor)

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
|
|
|
|
class QwenImageControlNetInputsStep(ModularPipelineBlocks):
    # Block that packs and batch-expands controlnet conditioning latents, for both
    # single- and multi-controlnet setups.
    model_name = "qwenimage"

    @property
    def description(self) -> str:
        return "prepare the `control_image_latents` for controlnet. Insert after all the other inputs steps."

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(name="control_image_latents", required=True),
            InputParam(name="batch_size", required=True),
            InputParam(name="num_images_per_prompt", default=1),
            InputParam(name="height"),
            InputParam(name="width"),
        ]

    @torch.no_grad()
    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
        """Patchify and batch-expand `control_image_latents`, filling height/width if unset."""
        block_state = self.get_block_state(state)

        if isinstance(components.controlnet, QwenImageMultiControlNetModel):
            # Multi-controlnet: control_image_latents is a list; process each entry independently.
            control_image_latents = []

            for i, control_image_latents_ in enumerate(block_state.control_image_latents):
                # Derive image dims from the latents; user-provided height/width win.
                height, width = calculate_dimension_from_latents(control_image_latents_, components.vae_scale_factor)
                block_state.height = block_state.height or height
                block_state.width = block_state.width or width

                # Patchify before batch expansion so each packed sample is repeated as a unit.
                control_image_latents_ = components.pachifier.pack_latents(control_image_latents_)

                control_image_latents_ = repeat_tensor_to_batch_size(
                    input_name=f"control_image_latents[{i}]",
                    input_tensor=control_image_latents_,
                    num_images_per_prompt=block_state.num_images_per_prompt,
                    batch_size=block_state.batch_size,
                )

                control_image_latents.append(control_image_latents_)

            block_state.control_image_latents = control_image_latents

        else:
            # Single controlnet: control_image_latents is a single tensor.
            height, width = calculate_dimension_from_latents(
                block_state.control_image_latents, components.vae_scale_factor
            )
            block_state.height = block_state.height or height
            block_state.width = block_state.width or width

            block_state.control_image_latents = components.pachifier.pack_latents(block_state.control_image_latents)

            block_state.control_image_latents = repeat_tensor_to_batch_size(
                input_name="control_image_latents",
                input_tensor=block_state.control_image_latents,
                num_images_per_prompt=block_state.num_images_per_prompt,
                batch_size=block_state.batch_size,
            )
            # Fix: removed a no-op self-assignment of block_state.control_image_latents here.

        self.set_block_state(state, block_state)

        return components, state
|
|
|