| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Any |
|
|
| import torch |
|
|
| from ...models import FluxTransformer2DModel |
| from ...schedulers import FlowMatchEulerDiscreteScheduler |
| from ...utils import logging |
| from ..modular_pipeline import ( |
| BlockState, |
| LoopSequentialPipelineBlocks, |
| ModularPipelineBlocks, |
| PipelineState, |
| ) |
| from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam |
| from .modular_pipeline import FluxModularPipeline |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
class FluxLoopDenoiser(ModularPipelineBlocks):
    """Denoising-loop sub-block that runs the Flux transformer once per step.

    Stores the transformer's prediction on ``block_state.noise_pred`` so that a
    later sub-block (e.g. ``FluxLoopAfterDenoiser``) can apply the scheduler
    update. Meant to be composed inside a ``LoopSequentialPipelineBlocks`` loop
    such as ``FluxDenoiseLoopWrapper``.
    """

    model_name = "flux"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        # Only the transformer is used here; the scheduler step lives in a separate block.
        return [ComponentSpec("transformer", FluxTransformer2DModel)]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoise the latents. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `FluxDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> list[InputParam]:
        # Annotation corrected: this property returns `InputParam` objects, not tuples.
        return [
            InputParam("joint_attention_kwargs"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "guidance",
                required=False,
                type_hint=torch.Tensor,
                description="Guidance scale as a tensor",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Prompt embeddings",
            ),
            InputParam(
                "pooled_prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Pooled prompt embeddings",
            ),
            InputParam(
                "txt_ids",
                required=True,
                type_hint=torch.Tensor,
                description="IDs computed from text sequence needed for RoPE",
            ),
            InputParam(
                "img_ids",
                required=True,
                type_hint=torch.Tensor,
                description="IDs computed from image sequence needed for RoPE",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> tuple[FluxModularPipeline, BlockState]:
        """Predict the noise for timestep ``t`` and stash it on the block state.

        Args:
            components: Pipeline holder providing the ``transformer`` component.
            block_state: Mutable per-loop state carrying latents and embeddings.
            i: Index of the current denoising step (unused here).
            t: Current scheduler timestep.

        Returns:
            The ``(components, block_state)`` pair, with ``block_state.noise_pred`` set.
        """
        # t / 1000 rescales the timestep — presumably scheduler timesteps span
        # [0, 1000] while the transformer expects [0, 1]; TODO confirm.
        noise_pred = components.transformer(
            hidden_states=block_state.latents,
            timestep=t.flatten() / 1000,
            guidance=block_state.guidance,
            encoder_hidden_states=block_state.prompt_embeds,
            pooled_projections=block_state.pooled_prompt_embeds,
            joint_attention_kwargs=block_state.joint_attention_kwargs,
            txt_ids=block_state.txt_ids,
            img_ids=block_state.img_ids,
            return_dict=False,
        )[0]
        block_state.noise_pred = noise_pred

        return components, block_state
|
|
|
|
class FluxKontextLoopDenoiser(ModularPipelineBlocks):
    """Denoising-loop sub-block for Flux Kontext.

    Runs the Flux transformer on the current latents, optionally conditioned on
    reference-image latents concatenated along the sequence dimension. The
    prediction is stored on ``block_state.noise_pred`` for a later sub-block
    (e.g. ``FluxLoopAfterDenoiser``) to turn into the next latents. Meant to be
    composed inside a ``LoopSequentialPipelineBlocks`` loop such as
    ``FluxDenoiseLoopWrapper``.
    """

    model_name = "flux-kontext"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        # Only the transformer runs here; the scheduler step is a separate block.
        return [ComponentSpec("transformer", FluxTransformer2DModel)]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoise the latents for Flux Kontext. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `FluxDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> list[InputParam]:
        # Annotation corrected from `list[tuple[str, Any]]`: this property returns
        # `InputParam` objects, matching the sibling denoiser blocks.
        return [
            InputParam("joint_attention_kwargs"),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "image_latents",
                type_hint=torch.Tensor,
                description="Image latents to use for the denoising process. Can be generated in prepare_latent step.",
            ),
            InputParam(
                "guidance",
                required=False,
                type_hint=torch.Tensor,
                description="Guidance scale as a tensor",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Prompt embeddings",
            ),
            InputParam(
                "pooled_prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Pooled prompt embeddings",
            ),
            InputParam(
                "txt_ids",
                required=True,
                type_hint=torch.Tensor,
                description="IDs computed from text sequence needed for RoPE",
            ),
            InputParam(
                "img_ids",
                required=True,
                type_hint=torch.Tensor,
                description="IDs computed from latent sequence needed for RoPE",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> tuple[FluxModularPipeline, BlockState]:
        """Predict the noise for timestep ``t`` and stash it on the block state.

        Args:
            components: Pipeline holder providing the ``transformer`` component.
            block_state: Mutable per-loop state carrying latents and embeddings.
            i: Index of the current denoising step (unused here).
            t: Current scheduler timestep.

        Returns:
            The ``(components, block_state)`` pair, with ``block_state.noise_pred`` set.
        """
        latents = block_state.latents
        latent_model_input = latents
        image_latents = block_state.image_latents
        if image_latents is not None:
            # Condition on the reference image by appending its tokens along the
            # sequence dimension; the extra tokens are stripped from the output below.
            latent_model_input = torch.cat([latent_model_input, image_latents], dim=1)

        # One timestep per batch element; divided by 1000 — presumably scheduler
        # timesteps span [0, 1000] while the transformer expects [0, 1]; TODO confirm.
        timestep = t.expand(latents.shape[0]).to(latents.dtype)
        noise_pred = components.transformer(
            hidden_states=latent_model_input,
            timestep=timestep / 1000,
            guidance=block_state.guidance,
            encoder_hidden_states=block_state.prompt_embeds,
            pooled_projections=block_state.pooled_prompt_embeds,
            joint_attention_kwargs=block_state.joint_attention_kwargs,
            txt_ids=block_state.txt_ids,
            img_ids=block_state.img_ids,
            return_dict=False,
        )[0]
        # Keep only the prediction for the actual latent tokens, discarding any
        # positions that correspond to the concatenated image-conditioning latents.
        noise_pred = noise_pred[:, : latents.size(1)]
        block_state.noise_pred = noise_pred

        return components, block_state
|
|
|
|
class FluxLoopAfterDenoiser(ModularPipelineBlocks):
    """Denoising-loop sub-block that applies the scheduler update to the latents.

    Consumes ``block_state.noise_pred`` produced by a preceding denoiser block
    and advances ``block_state.latents`` one step with the flow-match scheduler.
    Meant to be composed inside a ``LoopSequentialPipelineBlocks`` loop such as
    ``FluxDenoiseLoopWrapper``.
    """

    model_name = "flux"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        # Only the scheduler is needed; the transformer ran in the previous sub-block.
        return [ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler)]

    @property
    def description(self) -> str:
        return (
            "step within the denoising loop that update the latents. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `FluxDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> list[InputParam]:
        # No user-facing inputs of its own; everything comes from the loop state.
        return []

    @property
    def intermediate_inputs(self) -> list[InputParam]:
        # Annotation corrected from `list[str]`: the property returns `InputParam` objects.
        # NOTE(review): `generator` is declared but not read in `__call__` — verify upstream use.
        return [InputParam("generator")]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]

    @torch.no_grad()
    def __call__(
        self, components: FluxModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> tuple[FluxModularPipeline, BlockState]:
        """Advance the latents one scheduler step using the stored noise prediction."""

        latents_dtype = block_state.latents.dtype
        block_state.latents = components.scheduler.step(
            block_state.noise_pred,
            t,
            block_state.latents,
            return_dict=False,
        )[0]

        # The scheduler step may return latents in a different dtype
        # (presumably due to internal upcasting); restore the original one.
        if block_state.latents.dtype != latents_dtype:
            block_state.latents = block_state.latents.to(latents_dtype)

        return components, block_state
|
|
|
|
class FluxDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    """Outer denoising loop for Flux.

    Iterates over the scheduler timesteps and, at every step, runs the configured
    ``sub_blocks`` sequentially via ``loop_step``, keeping a progress bar in sync
    with the scheduler order.
    """

    model_name = "flux"

    @property
    def description(self) -> str:
        return (
            "Pipeline block that iteratively denoise the latents over `timesteps`. "
            "The specific steps with each iteration can be customized with `sub_blocks` attributes"
        )

    @property
    def loop_expected_components(self) -> list[ComponentSpec]:
        # Shared by the inner sub-blocks: the scheduler advances the latents,
        # the transformer predicts the noise.
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
            ComponentSpec("transformer", FluxTransformer2DModel),
        ]

    @property
    def loop_inputs(self) -> list[InputParam]:
        return [
            InputParam(
                "timesteps",
                required=True,
                type_hint=torch.Tensor,
                description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
            InputParam(
                "num_inference_steps",
                required=True,
                type_hint=int,
                description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
            ),
        ]

    @torch.no_grad()
    def __call__(self, components: FluxModularPipeline, state: PipelineState) -> PipelineState:
        """Run the full denoising loop and write the final state back."""
        block_state = self.get_block_state(state)

        # Extra leading timesteps consumed by higher-order schedulers before
        # regular updates begin ("warmup"); never negative.
        block_state.num_warmup_steps = max(
            len(block_state.timesteps) - block_state.num_inference_steps * components.scheduler.order, 0
        )

        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
            for step_index, timestep in enumerate(block_state.timesteps):
                components, block_state = self.loop_step(components, block_state, i=step_index, t=timestep)

                # Advance the bar on the final timestep, or once per completed
                # scheduler-order cycle after the warmup phase.
                reached_last_step = step_index == len(block_state.timesteps) - 1
                past_warmup = (step_index + 1) > block_state.num_warmup_steps
                order_cycle_done = (step_index + 1) % components.scheduler.order == 0
                if reached_last_step or (past_warmup and order_cycle_done):
                    progress_bar.update()

        self.set_block_state(state, block_state)

        return components, state
|
|
|
|
class FluxDenoiseStep(FluxDenoiseLoopWrapper):
    """Standard Flux denoise loop: transformer prediction then scheduler update."""

    block_classes = [FluxLoopDenoiser, FluxLoopAfterDenoiser]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        # Assembled line-by-line; the joined result is the documented contract.
        lines = [
            "Denoise step that iteratively denoise the latents. ",
            "Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method ",
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:",
            " - `FluxLoopDenoiser`",
            " - `FluxLoopAfterDenoiser`",
            "This block supports both text2image and img2img tasks.",
        ]
        return "\n".join(lines)
|
|
|
|
class FluxKontextDenoiseStep(FluxDenoiseLoopWrapper):
    """Flux Kontext denoise loop: image-conditioned prediction then scheduler update."""

    model_name = "flux-kontext"
    block_classes = [FluxKontextLoopDenoiser, FluxLoopAfterDenoiser]
    block_names = ["denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        # Assembled line-by-line; the joined result is the documented contract.
        lines = [
            "Denoise step that iteratively denoise the latents. ",
            "Its loop logic is defined in `FluxDenoiseLoopWrapper.__call__` method ",
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:",
            " - `FluxKontextLoopDenoiser`",
            " - `FluxLoopAfterDenoiser`",
            "This block supports both text2image and img2img tasks.",
        ]
        return "\n".join(lines)
|
|