| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Any |
|
|
| import torch |
|
|
| from ...configuration_utils import FrozenDict |
| from ...guiders import ClassifierFreeGuidance |
| from ...models import Flux2Transformer2DModel |
| from ...schedulers import FlowMatchEulerDiscreteScheduler |
| from ...utils import is_torch_xla_available, logging |
| from ..modular_pipeline import ( |
| BlockState, |
| LoopSequentialPipelineBlocks, |
| ModularPipelineBlocks, |
| PipelineState, |
| ) |
| from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam |
| from .modular_pipeline import Flux2KleinModularPipeline, Flux2ModularPipeline |
|
|
|
|
| if is_torch_xla_available(): |
| import torch_xla.core.xla_model as xm |
|
|
| XLA_AVAILABLE = True |
| else: |
| XLA_AVAILABLE = False |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
class Flux2LoopDenoiser(ModularPipelineBlocks):
    """Denoising sub-block for Flux2 that performs a single transformer forward pass.

    Designed to run inside a `LoopSequentialPipelineBlocks` loop (see
    `Flux2DenoiseLoopWrapper`): it reads the current latents from `block_state`
    and writes the noise prediction back as `block_state.noise_pred`.
    """

    model_name = "flux2"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [ComponentSpec("transformer", Flux2Transformer2DModel)]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoises the latents for Flux2. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `Flux2DenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> list[InputParam]:
        # NOTE: return annotation fixed — this returns `InputParam` objects,
        # not `(str, Any)` tuples.
        return [
            InputParam("joint_attention_kwargs"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The latents to denoise. Shape: (B, seq_len, C)",
            ),
            InputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
            ),
            InputParam(
                "image_latent_ids",
                type_hint=torch.Tensor,
                description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
            ),
            InputParam(
                "guidance",
                required=True,
                type_hint=torch.Tensor,
                description="Guidance scale as a tensor",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Text embeddings from Mistral3",
            ),
            InputParam(
                "txt_ids",
                required=True,
                type_hint=torch.Tensor,
                description="4D position IDs for text tokens (T, H, W, L)",
            ),
            InputParam(
                "latent_ids",
                required=True,
                type_hint=torch.Tensor,
                description="4D position IDs for latent tokens (T, H, W, L)",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> tuple[Flux2ModularPipeline, BlockState]:
        """Run one denoising step.

        Args:
            components: Pipeline providing the `transformer` component.
            block_state: Mutable per-iteration state; `noise_pred` is set on it.
            i: Index of the current timestep (unused here; part of the loop-step interface).
            t: Current timestep value.

        Returns:
            `(components, block_state)` with `block_state.noise_pred` populated.
        """
        latents = block_state.latents
        latent_model_input = latents.to(components.transformer.dtype)
        img_ids = block_state.latent_ids

        # Image conditioning: append packed image latents (and their position
        # ids) along the sequence dimension so the transformer attends to them.
        image_latents = getattr(block_state, "image_latents", None)
        if image_latents is not None:
            latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
            img_ids = torch.cat([img_ids, block_state.image_latent_ids], dim=1)

        # Broadcast the scalar timestep over the batch.
        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        noise_pred = components.transformer(
            hidden_states=latent_model_input,
            # Timesteps are scaled by 1/1000 before being fed to the transformer
            # (Flux convention; scheduler timesteps run on a 0-1000 scale).
            timestep=timestep / 1000,
            guidance=block_state.guidance,
            encoder_hidden_states=block_state.prompt_embeds,
            txt_ids=block_state.txt_ids,
            img_ids=img_ids,
            joint_attention_kwargs=block_state.joint_attention_kwargs,
            return_dict=False,
        )[0]

        # Drop predictions over the appended conditioning tokens; keep only the
        # span corresponding to the latents actually being denoised.
        block_state.noise_pred = noise_pred[:, : latents.size(1)]

        return components, block_state
|
|
|
|
| |
class Flux2KleinLoopDenoiser(ModularPipelineBlocks):
    """Denoising sub-block for the distilled Flux2 Klein model (single forward pass).

    Same structure as `Flux2LoopDenoiser`, but the Klein variant takes no
    `guidance` input and passes `guidance=None` to the transformer.
    """

    model_name = "flux2-klein"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [ComponentSpec("transformer", Flux2Transformer2DModel)]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoises the latents for Flux2. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `Flux2DenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> list[InputParam]:
        # NOTE: return annotation fixed — this returns `InputParam` objects,
        # not `(str, Any)` tuples.
        return [
            InputParam("joint_attention_kwargs"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The latents to denoise. Shape: (B, seq_len, C)",
            ),
            InputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
            ),
            InputParam(
                "image_latent_ids",
                type_hint=torch.Tensor,
                description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Text embeddings from Qwen3",
            ),
            InputParam(
                "txt_ids",
                required=True,
                type_hint=torch.Tensor,
                description="4D position IDs for text tokens (T, H, W, L)",
            ),
            InputParam(
                "latent_ids",
                required=True,
                type_hint=torch.Tensor,
                description="4D position IDs for latent tokens (T, H, W, L)",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> tuple[Flux2KleinModularPipeline, BlockState]:
        """Run one denoising step for the Klein variant.

        Args:
            components: Pipeline providing the `transformer` component.
            block_state: Mutable per-iteration state; `noise_pred` is set on it.
            i: Index of the current timestep (unused here; part of the loop-step interface).
            t: Current timestep value.

        Returns:
            `(components, block_state)` with `block_state.noise_pred` populated.
        """
        latents = block_state.latents
        latent_model_input = latents.to(components.transformer.dtype)
        img_ids = block_state.latent_ids

        # Image conditioning: append packed image latents (and their position
        # ids) along the sequence dimension so the transformer attends to them.
        image_latents = getattr(block_state, "image_latents", None)
        if image_latents is not None:
            latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
            img_ids = torch.cat([img_ids, block_state.image_latent_ids], dim=1)

        # Broadcast the scalar timestep over the batch.
        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        noise_pred = components.transformer(
            hidden_states=latent_model_input,
            # Timesteps are scaled by 1/1000 before being fed to the transformer
            # (Flux convention; scheduler timesteps run on a 0-1000 scale).
            timestep=timestep / 1000,
            guidance=None,  # Klein takes no embedded-guidance input
            encoder_hidden_states=block_state.prompt_embeds,
            txt_ids=block_state.txt_ids,
            img_ids=img_ids,
            joint_attention_kwargs=block_state.joint_attention_kwargs,
            return_dict=False,
        )[0]

        # Drop predictions over the appended conditioning tokens; keep only the
        # span corresponding to the latents actually being denoised.
        block_state.noise_pred = noise_pred[:, : latents.size(1)]

        return components, block_state
|
|
|
|
| |
class Flux2KleinBaseLoopDenoiser(ModularPipelineBlocks):
    """Denoising sub-block for the non-distilled Flux2 Klein base model.

    Unlike `Flux2KleinLoopDenoiser`, this block runs classifier-free guidance
    via a `guider` component: it performs one transformer forward pass per
    guider batch (conditional / unconditional) and lets the guider combine the
    resulting noise predictions.
    """

    model_name = "flux2-klein"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("transformer", Flux2Transformer2DModel),
            # CFG guider, created from config with a default guidance scale of 4.0.
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 4.0}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> list[ConfigSpec]:
        return [
            ConfigSpec(name="is_distilled", default=False),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoises the latents for Flux2. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `Flux2DenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> list[InputParam]:
        # NOTE: return annotation fixed — this returns `InputParam` objects,
        # not `(str, Any)` tuples.
        return [
            InputParam("joint_attention_kwargs"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The latents to denoise. Shape: (B, seq_len, C)",
            ),
            InputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="Packed image latents for conditioning. Shape: (B, img_seq_len, C)",
            ),
            InputParam(
                "image_latent_ids",
                type_hint=torch.Tensor,
                description="Position IDs for image latents. Shape: (B, img_seq_len, 4)",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Text embeddings from Qwen3",
            ),
            InputParam(
                "negative_prompt_embeds",
                required=False,
                type_hint=torch.Tensor,
                description="Negative text embeddings from Qwen3",
            ),
            InputParam(
                "txt_ids",
                required=True,
                type_hint=torch.Tensor,
                description="4D position IDs for text tokens (T, H, W, L)",
            ),
            InputParam(
                "negative_txt_ids",
                required=False,
                type_hint=torch.Tensor,
                description="4D position IDs for negative text tokens (T, H, W, L)",
            ),
            InputParam(
                "latent_ids",
                required=True,
                type_hint=torch.Tensor,
                description="4D position IDs for latent tokens (T, H, W, L)",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: Flux2KleinModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> tuple[Flux2KleinModularPipeline, BlockState]:
        """Run one CFG denoising step and store the combined prediction on `block_state.noise_pred`."""
        latents = block_state.latents
        latent_model_input = latents.to(components.transformer.dtype)
        img_ids = block_state.latent_ids

        # Image conditioning: append packed image latents (and their position
        # ids) along the sequence dimension.
        image_latents = getattr(block_state, "image_latents", None)
        if image_latents is not None:
            latent_model_input = torch.cat([latents, image_latents], dim=1).to(components.transformer.dtype)
            image_latent_ids = block_state.image_latent_ids
            img_ids = torch.cat([img_ids, image_latent_ids], dim=1)

        # Broadcast the scalar timestep over the batch.
        timestep = t.expand(latents.shape[0]).to(latents.dtype)

        # (conditional, unconditional) pairs for each transformer kwarg the
        # guider manages; negatives may be None (e.g. when CFG is disabled).
        guider_inputs = {
            "encoder_hidden_states": (
                getattr(block_state, "prompt_embeds", None),
                getattr(block_state, "negative_prompt_embeds", None),
            ),
            "txt_ids": (
                getattr(block_state, "txt_ids", None),
                getattr(block_state, "negative_txt_ids", None),
            ),
        }

        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
        # Presumably splits the (cond, uncond) pairs into per-branch batches;
        # exact semantics are defined by `ClassifierFreeGuidance.prepare_inputs`.
        guider_state = components.guider.prepare_inputs(guider_inputs)

        # One transformer forward pass per guider batch.
        for guider_state_batch in guider_state:
            components.guider.prepare_models(components.transformer)
            cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs.keys()}

            noise_pred = components.transformer(
                hidden_states=latent_model_input,
                timestep=timestep / 1000,  # scheduler timesteps are scaled by 1/1000 for the transformer
                guidance=None,  # guidance is applied by the guider, not embedded in the model
                img_ids=img_ids,
                joint_attention_kwargs=block_state.joint_attention_kwargs,
                return_dict=False,
                **cond_kwargs,
            )[0]
            # Keep only the span corresponding to the latents being denoised.
            guider_state_batch.noise_pred = noise_pred[:, : latents.size(1)]
            components.guider.cleanup_models(components.transformer)

        # Combine the per-branch predictions (classifier-free guidance).
        block_state.noise_pred = components.guider(guider_state)[0]

        return components, block_state
|
|
|
|
class Flux2LoopAfterDenoiser(ModularPipelineBlocks):
    """Post-denoise sub-block: advances the latents one scheduler step using `noise_pred`."""

    model_name = "flux2"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that updates the latents after denoising. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `Flux2DenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> list[InputParam]:
        return []

    @property
    def intermediate_inputs(self) -> list[InputParam]:
        # NOTE: return annotation fixed — this returns `InputParam` objects, not strings.
        return [InputParam("generator")]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
        """Apply one scheduler step to `block_state.latents` using `block_state.noise_pred`.

        Args:
            components: Pipeline providing the `scheduler` component.
            block_state: Mutable per-iteration state; `latents` is updated in place.
            i: Index of the current timestep (unused here; part of the loop-step interface).
            t: Current timestep value.

        Returns:
            `(components, block_state)` with `block_state.latents` advanced one step.
        """
        latents_dtype = block_state.latents.dtype
        block_state.latents = components.scheduler.step(
            block_state.noise_pred,
            t,
            block_state.latents,
            return_dict=False,
        )[0]

        # The scheduler step can change the latents' dtype on Apple MPS; restore
        # the original dtype there (known MPS workaround — TODO confirm still needed).
        if block_state.latents.dtype != latents_dtype and torch.backends.mps.is_available():
            block_state.latents = block_state.latents.to(latents_dtype)

        return components, block_state
|
|
|
|
class Flux2DenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    """Denoise-loop driver: iterates over `timesteps`, running `sub_blocks` once per step."""

    model_name = "flux2"

    @property
    def description(self) -> str:
        return (
            "Pipeline block that iteratively denoises the latents over `timesteps`. "
            "The specific steps within each iteration can be customized with `sub_blocks` attribute"
        )

    @property
    def loop_expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
            ComponentSpec("transformer", Flux2Transformer2DModel),
        ]

    @property
    def loop_inputs(self) -> list[InputParam]:
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process.",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: Flux2ModularPipeline, state: PipelineState) -> tuple[Flux2ModularPipeline, PipelineState]:
        """Run the full denoising loop, calling `self.loop_step` for every timestep."""
        block_state = self.get_block_state(state)

        # Timesteps beyond `num_inference_steps * scheduler.order` count as
        # warmup; the progress bar only advances once per scheduler-order group
        # after warmup (and always on the final timestep).
        block_state.num_warmup_steps = max(
            len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
        )

        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
            for i, t in enumerate(block_state.timesteps):
                # Runs all sub-blocks (denoiser, after-denoiser, ...) for step i.
                components, block_state = self.loop_step(components, block_state, i=i, t=t)

                if i == len(block_state.timesteps) - 1 or (
                    (i + 1) > block_state.num_warmup_steps and (i + 1) % components.scheduler.order == 0
                ):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    # Flush pending lazily-recorded XLA ops once per step.
                    xm.mark_step()

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
class Flux2DenoiseStep(Flux2DenoiseLoopWrapper):
    """Full Flux2 denoise loop: runs `Flux2LoopDenoiser` then `Flux2LoopAfterDenoiser` at each timestep."""

    # Sub-blocks executed sequentially within every loop iteration.
    block_classes = [Flux2LoopDenoiser, Flux2LoopAfterDenoiser]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents for Flux2. \n"
            "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `Flux2LoopDenoiser`\n"
            " - `Flux2LoopAfterDenoiser`\n"
            "This block supports both text-to-image and image-conditioned generation."
        )
|
|
|
|
class Flux2KleinDenoiseStep(Flux2DenoiseLoopWrapper):
    """Denoise loop for the distilled Flux2 Klein model (guidance-free denoiser sub-block)."""

    # Sub-blocks executed sequentially within every loop iteration.
    block_classes = [Flux2KleinLoopDenoiser, Flux2LoopAfterDenoiser]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents for Flux2. \n"
            "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `Flux2KleinLoopDenoiser`\n"
            " - `Flux2LoopAfterDenoiser`\n"
            "This block supports both text-to-image and image-conditioned generation."
        )
|
|
|
|
class Flux2KleinBaseDenoiseStep(Flux2DenoiseLoopWrapper):
    """Denoise loop for the Flux2 Klein base model (classifier-free-guidance denoiser sub-block)."""

    # Sub-blocks executed sequentially within every loop iteration.
    block_classes = [Flux2KleinBaseLoopDenoiser, Flux2LoopAfterDenoiser]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents for Flux2. \n"
            "Its loop logic is defined in `Flux2DenoiseLoopWrapper.__call__` method \n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `Flux2KleinBaseLoopDenoiser`\n"
            " - `Flux2LoopAfterDenoiser`\n"
            "This block supports both text-to-image and image-conditioned generation."
        )
|
|