Spaces:
Running on Zero
Running on Zero
| # Copyright 2025 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| from typing import Any | |
| import torch | |
| from ...configuration_utils import FrozenDict | |
| from ...guiders import ClassifierFreeGuidance | |
| from ...models import LTXVideoTransformer3DModel | |
| from ...schedulers import FlowMatchEulerDiscreteScheduler | |
| from ..modular_pipeline import ( | |
| BlockState, | |
| LoopSequentialPipelineBlocks, | |
| ModularPipelineBlocks, | |
| PipelineState, | |
| ) | |
| from ..modular_pipeline_utils import ComponentSpec, InputParam | |
| from .modular_pipeline import LTXModularPipeline, LTXVideoPachifier | |
class LTXLoopBeforeDenoiser(ModularPipelineBlocks):
    """Loop sub-block that builds the transformer input by casting the current latents to the run dtype."""

    model_name = "ltx"

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that prepares the latent input for the denoiser. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `LTXDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("latents", required=True),
            InputParam.template("dtype", required=True),
        ]

    @torch.no_grad()
    def __call__(
        self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> tuple[LTXModularPipeline, BlockState]:
        """Set ``block_state.latent_model_input`` to ``block_state.latents`` cast to ``block_state.dtype``.

        ``i`` and ``t`` are unused here; they are part of the loop sub-block call signature.
        """
        block_state.latent_model_input = block_state.latents.to(block_state.dtype)
        return components, block_state
class LTXLoopDenoiser(ModularPipelineBlocks):
    """Loop sub-block that runs the LTX transformer once per guidance batch and combines the predictions.

    Reads the latent input from ``block_state.latent_model_input`` and writes the guided prediction to
    ``block_state.noise_pred``.
    """

    model_name = "ltx"

    def __init__(
        self,
        guider_input_fields: dict[str, Any] | None = None,
    ):
        """
        Args:
            guider_input_fields: Maps each transformer keyword argument to the block-state field(s) feeding it.
                A value is either a single field name or a tuple of field names handed to the guider (the
                default uses ``(positive, negative)`` prompt pairs). Defaults to the prompt embeddings and the
                prompt attention masks.

        Raises:
            ValueError: If ``guider_input_fields`` is not a dictionary.
        """
        if guider_input_fields is None:
            guider_input_fields = {
                "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
                "encoder_attention_mask": ("prompt_attention_mask", "negative_prompt_attention_mask"),
            }
        if not isinstance(guider_input_fields, dict):
            raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
        # Shallow copy so later mutation of the caller's dict cannot change this block's declared inputs.
        self._guider_input_fields = dict(guider_input_fields)
        super().__init__()

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 3.0}),
                default_creation_method="from_config",
            ),
            ComponentSpec("transformer", LTXVideoTransformer3DModel),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that denoises the latents with guidance. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `LTXDenoiseLoopWrapper`)"
        )

    @property
    def inputs(self) -> list[InputParam]:
        inputs = [
            InputParam.template("attention_kwargs"),
            InputParam.template("num_inference_steps", required=True),
            InputParam("rope_interpolation_scale", type_hint=tuple),
            InputParam.template("height"),
            InputParam.template("width"),
            InputParam("num_frames", type_hint=int),
        ]
        # Every block-state field referenced by the guider mapping is a required tensor input.
        for value in self._guider_input_fields.values():
            names = value if isinstance(value, tuple) else (value,)
            inputs.extend(InputParam(name=name, required=True, type_hint=torch.Tensor) for name in names)
        return inputs

    @torch.no_grad()
    def __call__(
        self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> tuple[LTXModularPipeline, BlockState]:
        """Predict noise for loop step ``i`` at timestep ``t`` and store the guided result in ``block_state.noise_pred``.

        Raises:
            ValueError: If ``num_frames``, ``height`` or ``width`` is unset on the block state.
        """
        if block_state.num_frames is None or block_state.height is None or block_state.width is None:
            raise ValueError("`num_frames`, `height` and `width` must be set before running `LTXLoopDenoiser`.")

        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

        # Convert pixel-space video dimensions into the latent grid passed to the transformer.
        latent_num_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1
        latent_height = block_state.height // components.vae_spatial_compression_ratio
        latent_width = block_state.width // components.vae_spatial_compression_ratio

        # The hidden states (and hence the timestep) are the same for every guidance batch; build it once.
        timestep = t.expand(block_state.latent_model_input.shape[0]).to(block_state.dtype)

        guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
        for guider_state_batch in guider_state:
            components.guider.prepare_models(components.transformer)
            # Keep only the transformer kwargs this block is configured for, with tensors cast to the run dtype.
            cond_kwargs = {
                k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
                for k, v in guider_state_batch.as_dict().items()
                if k in self._guider_input_fields
            }
            context_name = getattr(guider_state_batch, components.guider._identifier_key, None)
            with components.transformer.cache_context(context_name):
                # Stored per batch; the guider call after the loop combines all batch predictions.
                guider_state_batch.noise_pred = components.transformer(
                    hidden_states=block_state.latent_model_input,
                    timestep=timestep,
                    num_frames=latent_num_frames,
                    height=latent_height,
                    width=latent_width,
                    rope_interpolation_scale=block_state.rope_interpolation_scale,
                    attention_kwargs=block_state.attention_kwargs,
                    return_dict=False,
                    **cond_kwargs,
                )[0]
            components.guider.cleanup_models(components.transformer)

        block_state.noise_pred = components.guider(guider_state)[0]
        return components, block_state
class LTXLoopAfterDenoiser(ModularPipelineBlocks):
    """Loop sub-block that advances the latents one scheduler step using ``block_state.noise_pred``."""

    model_name = "ltx"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the denoising loop that updates the latents. "
            "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
            "object (e.g. `LTXDenoiseLoopWrapper`)"
        )

    @torch.no_grad()
    def __call__(
        self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> tuple[LTXModularPipeline, BlockState]:
        """Replace ``block_state.latents`` with the scheduler's next sample, keeping the incoming latents dtype."""
        latents_dtype = block_state.latents.dtype
        block_state.latents = components.scheduler.step(
            block_state.noise_pred,
            t,
            block_state.latents,
            return_dict=False,
        )[0]
        # The scheduler step may return a different dtype; keep latents stable across iterations.
        if block_state.latents.dtype != latents_dtype:
            block_state.latents = block_state.latents.to(latents_dtype)
        return components, block_state
class LTXDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
    """Loop container that runs its sub-blocks once per entry of ``timesteps`` and reports progress."""

    model_name = "ltx"

    @property
    def description(self) -> str:
        return (
            "Pipeline block that iteratively denoises the latents over `timesteps`. "
            "The specific steps within each iteration can be customized with `sub_blocks` attributes"
        )

    @property
    def loop_expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
            ComponentSpec("transformer", LTXVideoTransformer3DModel),
        ]

    @property
    def loop_inputs(self) -> list[InputParam]:
        return [
            InputParam.template("timesteps", required=True),
            InputParam.template("num_inference_steps", required=True),
        ]

    @torch.no_grad()
    def __call__(
        self, components: LTXModularPipeline, state: PipelineState
    ) -> tuple[LTXModularPipeline, PipelineState]:
        """Run every loop sub-block for each timestep, then write the updated block state back to ``state``."""
        block_state = self.get_block_state(state)
        num_timesteps = len(block_state.timesteps)
        scheduler_order = components.scheduler.order
        # Timesteps beyond num_inference_steps * order are treated as warmup and do not advance the progress bar.
        block_state.num_warmup_steps = max(num_timesteps - block_state.num_inference_steps * scheduler_order, 0)

        with self.progress_bar(total=block_state.num_inference_steps) as progress_bar:
            for i, t in enumerate(block_state.timesteps):
                components, block_state = self.loop_step(components, block_state, i=i, t=t)
                is_last_step = i == num_timesteps - 1
                if is_last_step or ((i + 1) > block_state.num_warmup_steps and (i + 1) % scheduler_order == 0):
                    progress_bar.update()

        self.set_block_state(state, block_state)
        return components, state
class LTXDenoiseStep(LTXDenoiseLoopWrapper):
    """Text-to-video denoising loop built from the LTX before / denoiser / after loop sub-blocks."""

    # The denoiser is listed as a configured instance (not a class) to pin its guider input mapping.
    block_classes = [
        LTXLoopBeforeDenoiser,
        LTXLoopDenoiser(
            guider_input_fields={
                "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
                "encoder_attention_mask": ("prompt_attention_mask", "negative_prompt_attention_mask"),
            }
        ),
        LTXLoopAfterDenoiser,
    ]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step that iteratively denoises the latents.\n"
            "Its loop logic is defined in `LTXDenoiseLoopWrapper.__call__` method.\n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `LTXLoopBeforeDenoiser`\n"
            " - `LTXLoopDenoiser`\n"
            " - `LTXLoopAfterDenoiser`\n"
            "This block supports text-to-video tasks."
        )
class LTXImage2VideoLoopBeforeDenoiser(ModularPipelineBlocks):
    """Image-to-video loop sub-block that prepares the latent input and a mask-modulated timestep."""

    model_name = "ltx"

    @property
    def description(self) -> str:
        return (
            "Step within the i2v denoising loop that prepares the latent input and modulates "
            "the timestep with the conditioning mask."
        )

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("latents", required=True),
            InputParam("conditioning_mask", required=True, type_hint=torch.Tensor),
            InputParam.template("dtype", required=True),
        ]

    @torch.no_grad()
    def __call__(
        self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> tuple[LTXModularPipeline, BlockState]:
        """Set ``latent_model_input`` and ``timestep_adjusted`` on ``block_state``.

        ``timestep_adjusted`` is ``t`` scaled by ``(1 - conditioning_mask)``, so positions where the mask is 1
        receive timestep 0.
        """
        block_state.latent_model_input = block_state.latents.to(block_state.dtype)
        batch_size = block_state.latent_model_input.shape[0]
        # NOTE(review): assumes conditioning_mask broadcasts against (batch_size, 1), e.g. (batch, tokens) — confirm.
        per_sample_t = t.expand(batch_size).unsqueeze(-1)
        block_state.timestep_adjusted = per_sample_t * (1 - block_state.conditioning_mask)
        return components, block_state
class LTXImage2VideoLoopDenoiser(ModularPipelineBlocks):
    """Image-to-video loop sub-block that runs the transformer per guidance batch with a mask-modulated timestep.

    Reads ``block_state.latent_model_input`` and ``block_state.timestep_adjusted`` and writes the guided
    prediction to ``block_state.noise_pred``.
    """

    model_name = "ltx"

    def __init__(
        self,
        guider_input_fields: dict[str, Any] | None = None,
    ):
        """
        Args:
            guider_input_fields: Maps each transformer keyword argument to the block-state field(s) feeding it.
                A value is either a single field name or a tuple of field names handed to the guider (the
                default uses ``(positive, negative)`` prompt pairs). Defaults to the prompt embeddings and the
                prompt attention masks.

        Raises:
            ValueError: If ``guider_input_fields`` is not a dictionary.
        """
        if guider_input_fields is None:
            guider_input_fields = {
                "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
                "encoder_attention_mask": ("prompt_attention_mask", "negative_prompt_attention_mask"),
            }
        if not isinstance(guider_input_fields, dict):
            raise ValueError(f"guider_input_fields must be a dictionary but is {type(guider_input_fields)}")
        # Shallow copy so later mutation of the caller's dict cannot change this block's declared inputs.
        self._guider_input_fields = dict(guider_input_fields)
        super().__init__()

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 3.0}),
                default_creation_method="from_config",
            ),
            ComponentSpec("transformer", LTXVideoTransformer3DModel),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the i2v denoising loop that denoises the latents with guidance "
            "using timestep modulated by the conditioning mask."
        )

    @property
    def inputs(self) -> list[InputParam]:
        inputs = [
            InputParam.template("attention_kwargs"),
            InputParam.template("num_inference_steps", required=True),
            InputParam("rope_interpolation_scale", type_hint=tuple),
            InputParam.template("height"),
            InputParam.template("width"),
            InputParam("num_frames", type_hint=int),
        ]
        # Every block-state field referenced by the guider mapping is a required tensor input.
        for value in self._guider_input_fields.values():
            names = value if isinstance(value, tuple) else (value,)
            inputs.extend(InputParam(name=name, required=True, type_hint=torch.Tensor) for name in names)
        return inputs

    @torch.no_grad()
    def __call__(
        self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> tuple[LTXModularPipeline, BlockState]:
        """Predict noise for loop step ``i`` and store the guided result in ``block_state.noise_pred``.

        Raises:
            ValueError: If ``num_frames``, ``height`` or ``width`` is unset on the block state.
        """
        if block_state.num_frames is None or block_state.height is None or block_state.width is None:
            raise ValueError(
                "`num_frames`, `height` and `width` must be set before running `LTXImage2VideoLoopDenoiser`."
            )

        components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)

        # Convert pixel-space video dimensions into the latent grid passed to the transformer.
        latent_num_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1
        latent_height = block_state.height // components.vae_spatial_compression_ratio
        latent_width = block_state.width // components.vae_spatial_compression_ratio

        guider_state = components.guider.prepare_inputs_from_block_state(block_state, self._guider_input_fields)
        for guider_state_batch in guider_state:
            components.guider.prepare_models(components.transformer)
            # Keep only the transformer kwargs this block is configured for, with tensors cast to the run dtype.
            cond_kwargs = {
                k: v.to(block_state.dtype) if isinstance(v, torch.Tensor) else v
                for k, v in guider_state_batch.as_dict().items()
                if k in self._guider_input_fields
            }
            context_name = getattr(guider_state_batch, components.guider._identifier_key, None)
            with components.transformer.cache_context(context_name):
                # Stored per batch; the guider call after the loop combines all batch predictions.
                guider_state_batch.noise_pred = components.transformer(
                    hidden_states=block_state.latent_model_input,
                    timestep=block_state.timestep_adjusted,
                    num_frames=latent_num_frames,
                    height=latent_height,
                    width=latent_width,
                    rope_interpolation_scale=block_state.rope_interpolation_scale,
                    attention_kwargs=block_state.attention_kwargs,
                    return_dict=False,
                    **cond_kwargs,
                )[0]
            components.guider.cleanup_models(components.transformer)

        block_state.noise_pred = components.guider(guider_state)[0]
        return components, block_state
class LTXImage2VideoLoopAfterDenoiser(ModularPipelineBlocks):
    """Image-to-video loop sub-block that steps the scheduler on every latent frame except the first one."""

    model_name = "ltx"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("scheduler", FlowMatchEulerDiscreteScheduler),
            ComponentSpec(
                "pachifier",
                LTXVideoPachifier,
                config=FrozenDict({"patch_size": 1, "patch_size_t": 1}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return (
            "Step within the i2v denoising loop that updates the latents, "
            "applying the scheduler step only to frames after the first (conditioned) frame."
        )

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("height"),
            InputParam.template("width"),
            InputParam("num_frames", type_hint=int),
        ]

    @torch.no_grad()
    def __call__(
        self, components: LTXModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
    ) -> tuple[LTXModularPipeline, BlockState]:
        """Advance ``block_state.latents`` one scheduler step, leaving the first latent frame unchanged.

        The latents keep the dtype they had on entry.

        Raises:
            ValueError: If ``num_frames``, ``height`` or ``width`` is unset on the block state.
        """
        if block_state.num_frames is None or block_state.height is None or block_state.width is None:
            raise ValueError(
                "`num_frames`, `height` and `width` must be set before running `LTXImage2VideoLoopAfterDenoiser`."
            )

        latent_num_frames = (block_state.num_frames - 1) // components.vae_temporal_compression_ratio + 1
        latent_height = block_state.height // components.vae_spatial_compression_ratio
        latent_width = block_state.width // components.vae_spatial_compression_ratio
        latents_dtype = block_state.latents.dtype

        # Unpack the packed sequences so frames can be addressed along dim 2.
        noise_pred = components.pachifier.unpack_latents(
            block_state.noise_pred, latent_num_frames, latent_height, latent_width
        )
        latents = components.pachifier.unpack_latents(
            block_state.latents, latent_num_frames, latent_height, latent_width
        )

        # Frame 0 carries the conditioning image and is kept as-is; only later frames are stepped.
        pred_latents = components.scheduler.step(noise_pred[:, :, 1:], t, latents[:, :, 1:], return_dict=False)[0]
        latents = torch.cat([latents[:, :, :1], pred_latents], dim=2)

        # torch.cat type-promotes mixed dtypes; restore the incoming dtype so latents stay stable across steps.
        block_state.latents = components.pachifier.pack_latents(latents).to(latents_dtype)
        return components, block_state
class LTXImage2VideoDenoiseStep(LTXDenoiseLoopWrapper):
    """Image-to-video denoising loop built from the LTX i2v before / denoiser / after loop sub-blocks."""

    # The denoiser is listed as a configured instance (not a class) to pin its guider input mapping.
    block_classes = [
        LTXImage2VideoLoopBeforeDenoiser,
        LTXImage2VideoLoopDenoiser(
            guider_input_fields={
                "encoder_hidden_states": ("prompt_embeds", "negative_prompt_embeds"),
                "encoder_attention_mask": ("prompt_attention_mask", "negative_prompt_attention_mask"),
            }
        ),
        LTXImage2VideoLoopAfterDenoiser,
    ]
    block_names = ["before_denoiser", "denoiser", "after_denoiser"]

    @property
    def description(self) -> str:
        return (
            "Denoise step for image-to-video that iteratively denoises the latents.\n"
            "The first frame is kept fixed via a conditioning mask.\n"
            "At each iteration, it runs blocks defined in `sub_blocks` sequentially:\n"
            " - `LTXImage2VideoLoopBeforeDenoiser`\n"
            " - `LTXImage2VideoLoopDenoiser`\n"
            " - `LTXImage2VideoLoopAfterDenoiser`"
        )