Spaces:
Running on Zero
Running on Zero
| # Copyright 2025 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import inspect | |
| import math | |
| import torch | |
| import torch.nn.functional as F | |
| from tqdm.auto import tqdm | |
| from ...configuration_utils import FrozenDict | |
| from ...guiders import ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance | |
| from ...models import HeliosTransformer3DModel | |
| from ...schedulers import HeliosScheduler | |
| from ...utils import logging | |
| from ...utils.torch_utils import randn_tensor | |
| from ..modular_pipeline import ( | |
| BlockState, | |
| LoopSequentialPipelineBlocks, | |
| ModularPipelineBlocks, | |
| PipelineState, | |
| ) | |
| from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam | |
| from .before_denoise import calculate_shift | |
| from .modular_pipeline import HeliosModularPipeline | |
| logger = logging.get_logger(__name__) # pylint: disable=invalid-name | |
def sample_block_noise(
    batch_size,
    channel,
    num_frames,
    height,
    width,
    gamma,
    patch_size=(1, 2, 2),
    device=None,
    generator=None,
):
    """Generate spatially-correlated block noise for pyramid upsampling correction.

    Every ``ph x pw`` spatial block is an independent draw from a zero-mean multivariate normal whose covariance is
    ``(1 + gamma) * I - gamma * ones`` (plus a tiny diagonal jitter), sampled as ``z @ chol.T`` with ``z`` standard
    normal. The per-block samples are then tiled back into a ``(batch, channel, frames, height, width)`` tensor.

    Args:
        batch_size: Size of the leading (batch) dimension of the output.
        channel: Number of channels in the output.
        num_frames: Number of frames in the output.
        height: Output height; must be a multiple of ``patch_size[1]``.
        width: Output width; must be a multiple of ``patch_size[2]``.
        gamma: Strength of the negative correlation between entries of one block.
        patch_size: ``(pt, ph, pw)`` triple; only ``ph`` and ``pw`` are used here.
        device: Device for the covariance factor and for the returned noise.
        generator: A ``torch.Generator`` or a list/tuple of them (only the first is used). When ``None`` an unseeded
            generator is created, so results are not reproducible.

    Returns:
        Noise tensor of shape ``(batch_size, channel, num_frames, height, width)``.

    Raises:
        ValueError: If ``height`` or ``width`` is not divisible by the spatial patch size.
    """
    _, ph, pw = patch_size
    # Fail early with a readable message instead of an opaque reshape error at the end.
    if height % ph != 0 or width % pw != 0:
        raise ValueError(
            f"`height` ({height}) and `width` ({width}) must be divisible by the spatial patch size ({ph}, {pw})."
        )

    # NOTE: A generator must be provided to ensure correct and reproducible results.
    # Creating a default generator here is a fallback only — without a fixed seed,
    # the output will be non-deterministic and may produce incorrect results in CP context.
    if generator is None:
        generator = torch.Generator(device=device)
    elif isinstance(generator, (list, tuple)):
        generator = generator[0]

    block_size = ph * pw
    cov = (
        torch.eye(block_size, device=device) * (1 + gamma) - torch.ones(block_size, block_size, device=device) * gamma
    )
    # NOTE(review): with float32 entries close to 1, a 1e-8 jitter is below the representable resolution and is
    # likely a no-op — confirm it has the intended regularising effect.
    cov += torch.eye(block_size, device=device) * 1e-8
    cov = cov.float()  # Upcast to fp32 for numerical stability — cholesky is unreliable in fp16/bf16.
    chol = torch.linalg.cholesky(cov)

    blocks_h = height // ph
    blocks_w = width // pw
    block_number = batch_size * channel * num_frames * blocks_h * blocks_w
    # Sample on the generator's own device, then move to the requested one.
    z = torch.randn(block_number, block_size, device=generator.device, generator=generator).to(device)
    noise = z @ chol.T

    # (…, blocks_h, blocks_w, ph, pw) -> (…, blocks_h, ph, blocks_w, pw) -> (…, height, width)
    noise = noise.view(batch_size, channel, num_frames, blocks_h, blocks_w, ph, pw)
    noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(batch_size, channel, num_frames, height, width)
    return noise
| # ======================================== | |
| # Chunk Loop Leaf Blocks | |
| # ======================================== | |
class HeliosChunkHistorySliceStep(ModularPipelineBlocks):
    """Split the tail of the history latents into long/mid/short buffers for a T2V chunk.

    With ``keep_first_frame`` enabled, a one-frame prefix is prepended to the short buffer: zeros on the first chunk
    when no ``image_latents`` exist yet, otherwise ``block_state.image_latents``.
    """

    model_name = "helios"

    # NOTE(review): no decorators are visible on these accessors in this view — confirm whether the base class
    # expects them to be properties.
    def description(self) -> str:
        return (
            "T2V history slice: splits history into long/mid/short. At k==0 with no image_latents, "
            "creates a zero prefix; otherwise uses image_latents as prefix for short history."
        )

    def inputs(self) -> list[InputParam]:
        keep_first_frame = InputParam(
            "keep_first_frame",
            default=True,
            type_hint=bool,
            description="Whether to keep the first frame as a prefix in history.",
        )
        history_sizes = InputParam(
            "history_sizes",
            required=True,
            type_hint=list,
            description="Sizes of long/mid/short history buffers for temporal context.",
        )
        history_latents = InputParam(
            "history_latents",
            required=True,
            type_hint=torch.Tensor,
            description="Accumulated history latents from previous chunks.",
        )
        latent_shape = InputParam("latent_shape", required=True, type_hint=tuple)
        return [keep_first_frame, history_sizes, history_latents, latent_shape]

    def intermediate_outputs(self) -> list[OutputParam]:
        return []

    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        use_prefix = block_state.keep_first_frame
        sizes = block_state.history_sizes
        prefix = block_state.image_latents
        execution_device = components._execution_device
        batch_size, num_channels_latents, _, h_latent, w_latent = block_state.latent_shape

        # Only the last sum(sizes) frames matter; split them along the frame axis (dim=2).
        recent = block_state.history_latents[:, :, -sum(sizes) :]
        history_long, history_mid, history_tail = recent.split(sizes, dim=2)

        if use_prefix:
            if prefix is None and k == 0:
                # First chunk without image latents: use an all-zero single-frame prefix.
                prefix = torch.zeros(
                    batch_size,
                    num_channels_latents,
                    1,
                    h_latent,
                    w_latent,
                    device=execution_device,
                    dtype=torch.float32,
                )
            history_short = torch.cat([prefix, history_tail], dim=2)
        else:
            history_short = history_tail

        block_state.latents_history_short = history_short
        block_state.latents_history_mid = history_mid
        block_state.latents_history_long = history_long
        return components, block_state
class HeliosI2VChunkHistorySliceStep(ModularPipelineBlocks):
    """Split the tail of the history latents into long/mid/short buffers for an I2V chunk.

    Unlike the T2V variant there is no zero-prefix fallback: when ``keep_first_frame`` is set, ``image_latents`` is
    always prepended to the short buffer.
    """

    model_name = "helios"

    # NOTE(review): no decorators are visible on these accessors in this view — confirm whether the base class
    # expects them to be properties.
    def description(self) -> str:
        return (
            "I2V history slice: splits pre-seeded history into long/mid/short, "
            "always using image_latents as prefix for short history."
        )

    def inputs(self) -> list[InputParam]:
        keep_first_frame = InputParam(
            "keep_first_frame",
            default=True,
            type_hint=bool,
            description="Whether to keep the first frame as a prefix in history.",
        )
        history_sizes = InputParam(
            "history_sizes",
            required=True,
            type_hint=list,
            description="Sizes of long/mid/short history buffers for temporal context.",
        )
        history_latents = InputParam(
            "history_latents",
            required=True,
            type_hint=torch.Tensor,
            description="Accumulated history latents from previous chunks.",
        )
        image_latents = InputParam(
            "image_latents",
            required=True,
            type_hint=torch.Tensor,
            description="First-frame latents used as prefix for short history.",
        )
        return [keep_first_frame, history_sizes, history_latents, image_latents]

    def intermediate_outputs(self) -> list[OutputParam]:
        return []

    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        use_prefix = block_state.keep_first_frame
        sizes = block_state.history_sizes
        prefix = block_state.image_latents

        # Only the last sum(sizes) frames matter; split them along the frame axis (dim=2).
        recent = block_state.history_latents[:, :, -sum(sizes) :]
        history_long, history_mid, history_tail = recent.split(sizes, dim=2)

        history_short = torch.cat([prefix, history_tail], dim=2) if use_prefix else history_tail

        block_state.latents_history_short = history_short
        block_state.latents_history_mid = history_mid
        block_state.latents_history_long = history_long
        return components, block_state
class HeliosChunkNoiseGenStep(ModularPipelineBlocks):
    """Draw fresh float32 noise of shape ``latent_shape`` for one chunk and store it as ``latents``."""

    model_name = "helios"

    # NOTE(review): no decorators are visible on these accessors in this view — confirm whether the base class
    # expects them to be properties.
    def description(self) -> str:
        return "Generates random noise latents at full resolution for a single chunk."

    def inputs(self) -> list[InputParam]:
        shape_param = InputParam("latent_shape", required=True, type_hint=tuple)
        generator_param = InputParam.template("generator")
        return [shape_param, generator_param]

    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        execution_device = components._execution_device
        noise = randn_tensor(
            block_state.latent_shape,
            generator=block_state.generator,
            device=execution_device,
            dtype=torch.float32,
        )
        block_state.latents = noise
        return components, block_state
class HeliosPyramidChunkNoiseGenStep(ModularPipelineBlocks):
    """Draw full-resolution noise, then shrink it to the smallest pyramid level.

    The spatial size is halved once per pyramid stage after the first. Each halving uses bilinear interpolation and
    multiplies the result by 2.
    """

    model_name = "helios-pyramid"

    # NOTE(review): no decorators are visible on these accessors in this view — confirm whether the base class
    # expects them to be properties.
    def description(self) -> str:
        return (
            "Generates random noise at full resolution, then downsamples to the smallest "
            "pyramid level via bilinear interpolation."
        )

    def inputs(self) -> list[InputParam]:
        shape_param = InputParam("latent_shape", required=True, type_hint=tuple)
        steps_param = InputParam(
            "pyramid_num_inference_steps_list",
            default=[10, 10, 10],
            type_hint=list,
            description="Number of denoising steps per pyramid stage.",
        )
        generator_param = InputParam.template("generator")
        return [shape_param, steps_param, generator_param]

    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        execution_device = components._execution_device
        shape = block_state.latent_shape
        batch, channels, frames, cur_h, cur_w = shape
        noise = randn_tensor(shape, generator=block_state.generator, device=execution_device, dtype=torch.float32)

        # Fold frames into the batch axis so the 2D interpolation acts on each frame independently.
        per_frame = noise.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, cur_h, cur_w)

        num_halvings = len(block_state.pyramid_num_inference_steps_list) - 1
        for _ in range(num_halvings):
            cur_h //= 2
            cur_w //= 2
            per_frame = F.interpolate(per_frame, size=(cur_h, cur_w), mode="bilinear") * 2

        # Restore the (batch, channels, frames, h, w) layout at the reduced resolution.
        restored = per_frame.reshape(batch, frames, channels, cur_h, cur_w)
        block_state.latents = restored.permute(0, 2, 1, 3, 4)
        return components, block_state
class HeliosChunkSchedulerResetStep(ModularPipelineBlocks):
    """Re-initialise the scheduler's timesteps for one chunk and expose them as ``block_state.timesteps``."""

    model_name = "helios"

    # NOTE(review): no decorators are visible on these accessors in this view — confirm whether the base class
    # expects them to be properties.
    def description(self) -> str:
        return "Resets the scheduler with the correct timesteps and shift parameter (mu) for this chunk."

    def expected_components(self) -> list[ComponentSpec]:
        return [ComponentSpec("scheduler", HeliosScheduler)]

    def inputs(self) -> list[InputParam]:
        mu_param = InputParam("mu", required=True, type_hint=float)
        sigmas_param = InputParam.template("sigmas", required=True)
        steps_param = InputParam.template("num_inference_steps")
        return [mu_param, sigmas_param, steps_param]

    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        execution_device = components._execution_device
        scheduler = components.scheduler
        scheduler.set_timesteps(
            block_state.num_inference_steps,
            device=execution_device,
            sigmas=block_state.sigmas,
            mu=block_state.mu,
        )
        # Publish the freshly computed schedule for the denoising block that follows.
        block_state.timesteps = scheduler.timesteps
        return components, block_state
| # ======================================== | |
| # Inner Denoising Blocks | |
| # ======================================== | |
class HeliosChunkDenoiseInner(ModularPipelineBlocks):
    """Run the per-timestep denoising loop for one chunk.

    At every timestep the guider prepares one input batch per guidance branch, the transformer is run for each of
    them inside its own cache context, the guider combines the predictions, and the scheduler advances the latents.
    """

    model_name = "helios"

    # NOTE(review): no decorators are visible on these accessors in this view — confirm whether the base class
    # expects them to be properties.
    def description(self) -> str:
        return (
            "Inner denoising loop that iterates over timesteps for a single chunk. "
            "Uses the guider to manage conditional/unconditional forward passes with cache_context, "
            "applies guidance, and runs scheduler step."
        )

    def expected_components(self) -> list[ComponentSpec]:
        guider_spec = ComponentSpec(
            "guider",
            ClassifierFreeGuidance,
            config=FrozenDict({"guidance_scale": 5.0}),
            default_creation_method="from_config",
        )
        return [
            ComponentSpec("transformer", HeliosTransformer3DModel),
            ComponentSpec("scheduler", HeliosScheduler),
            guider_spec,
        ]

    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("latents"),
            InputParam.template("timesteps"),
            InputParam("prompt_embeds", type_hint=torch.Tensor),
            InputParam("negative_prompt_embeds", type_hint=torch.Tensor),
            InputParam.template("denoiser_input_fields"),
            InputParam.template("num_inference_steps"),
            InputParam.template("attention_kwargs"),
            InputParam.template("generator"),
        ]

    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        latents = block_state.latents
        timesteps = block_state.timesteps
        num_inference_steps = block_state.num_inference_steps
        model_dtype = components.transformer.dtype
        num_warmup_steps = len(timesteps) - num_inference_steps * components.scheduler.order

        # The only input that differs between guidance branches is the text conditioning.
        guider_inputs = {
            "encoder_hidden_states": (block_state.prompt_embeds, block_state.negative_prompt_embeds),
        }

        # Forward only those denoiser fields the transformer accepts and the guider does not manage.
        forward_params = set(inspect.signature(components.transformer.forward).parameters.keys())
        shared_kwargs = {
            name: value
            for name, value in block_state.denoiser_input_fields.items()
            if name in forward_params and name not in guider_inputs
        }

        # History buffers come from the slice step and are cast to the transformer's dtype.
        shared_kwargs["latents_history_short"] = block_state.latents_history_short.to(model_dtype)
        shared_kwargs["latents_history_mid"] = block_state.latents_history_mid.to(model_dtype)
        shared_kwargs["latents_history_long"] = block_state.latents_history_long.to(model_dtype)
        shared_kwargs["attention_kwargs"] = block_state.attention_kwargs

        last_step = len(timesteps) - 1
        with tqdm(total=num_inference_steps) as progress_bar:
            for step_index, t in enumerate(timesteps):
                timestep = t.expand(latents.shape[0]).to(torch.int64)
                model_input = latents.to(model_dtype)

                components.guider.set_state(step=step_index, num_inference_steps=num_inference_steps, timestep=t)
                guider_state = components.guider.prepare_inputs(guider_inputs)

                for branch in guider_state:
                    components.guider.prepare_models(components.transformer)
                    branch_kwargs = {name: getattr(branch, name) for name in guider_inputs.keys()}
                    context_name = getattr(branch, components.guider._identifier_key)
                    with components.transformer.cache_context(context_name):
                        branch.noise_pred = components.transformer(
                            hidden_states=model_input,
                            timestep=timestep,
                            return_dict=False,
                            **branch_kwargs,
                            **shared_kwargs,
                        )[0]
                    components.guider.cleanup_models(components.transformer)

                # Combine the per-branch predictions into a single guided prediction.
                noise_pred = components.guider(guider_state)[0]

                latents = components.scheduler.step(
                    noise_pred,
                    t,
                    latents,
                    generator=block_state.generator,
                    return_dict=False,
                )[0]

                # Tick the bar on the last step, or once per scheduler-order group after warmup.
                if step_index == last_step or (
                    (step_index + 1) > num_warmup_steps and (step_index + 1) % components.scheduler.order == 0
                ):
                    progress_bar.update()

        block_state.latents = latents
        return components, block_state
class HeliosPyramidChunkDenoiseInner(ModularPipelineBlocks):
    """Nested pyramid stage loop with inner timestep denoising.

    For each pyramid stage (small -> full resolution):
        1. Compute mu from the current (pre-upsample) resolution and set scheduler timesteps
        2. Upsample latents + block noise correction (stages > 0)
        3. Run timestep denoising loop (same logic as HeliosChunkDenoiseInner)

    If the guider exposes ``zero_init_steps`` it is forced to 0 for stages > 0 and restored to its original value
    when the block exits, including when a stage raises.
    """

    model_name = "helios-pyramid"

    # NOTE(review): no decorators are visible on these accessors in this view — confirm whether the base class
    # expects them to be properties.
    def description(self) -> str:
        return (
            "Pyramid denoising inner block: loops over pyramid stages from smallest to full resolution. "
            "Each stage upsamples latents (with block noise correction), recomputes scheduler parameters, "
            "and runs the timestep denoising loop."
        )

    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("transformer", HeliosTransformer3DModel),
            ComponentSpec("scheduler", HeliosScheduler),
            ComponentSpec(
                "guider",
                ClassifierFreeZeroStarGuidance,
                config=FrozenDict({"guidance_scale": 5.0, "zero_init_steps": 2}),
                default_creation_method="from_config",
            ),
        ]

    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("latents"),
            InputParam("prompt_embeds", type_hint=torch.Tensor),
            InputParam("negative_prompt_embeds", type_hint=torch.Tensor),
            InputParam.template("denoiser_input_fields"),
            InputParam(
                "pyramid_num_inference_steps_list",
                default=[10, 10, 10],
                type_hint=list,
                description="Number of denoising steps per pyramid stage.",
            ),
            InputParam.template("attention_kwargs"),
            InputParam.template("generator"),
        ]

    def _upsample_with_block_noise(
        self, components, latents, stage_index, patch_size, device, transformer_dtype, generator
    ):
        """Double the spatial size with nearest-neighbour upsampling and blend in correlated block noise."""
        batch_size, num_channels_latents, num_frames, current_h, current_w = latents.shape
        new_h = current_h * 2
        new_w = current_w * 2

        # Fold frames into the batch axis so the 2D interpolation acts on each frame independently.
        latents = latents.permute(0, 2, 1, 3, 4).reshape(
            batch_size * num_frames, num_channels_latents, current_h, current_w
        )
        latents = F.interpolate(latents, size=(new_h, new_w), mode="nearest")
        latents = latents.reshape(batch_size, num_frames, num_channels_latents, new_h, new_w).permute(0, 2, 1, 3, 4)

        # Block noise correction: mixing weights derived from the stage's start sigma and the scheduler's gamma.
        ori_sigma = 1 - components.scheduler.ori_start_sigmas[stage_index]
        gamma = components.scheduler.config.gamma
        alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
        beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)

        noise = sample_block_noise(
            batch_size,
            num_channels_latents,
            num_frames,
            new_h,
            new_w,
            gamma,
            patch_size,
            device=device,
            generator=generator,
        )
        noise = noise.to(dtype=transformer_dtype)
        return alpha * latents + beta * noise

    def _denoise_stage(
        self, components, latents, timesteps, num_inference_steps, guider_inputs, shared_kwargs, generator
    ):
        """Run the guided timestep loop for one pyramid stage and return the updated latents."""
        transformer_dtype = components.transformer.dtype
        num_warmup_steps = len(timesteps) - num_inference_steps * components.scheduler.order

        with tqdm(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                timestep = t.expand(latents.shape[0]).to(torch.int64)
                latent_model_input = latents.to(transformer_dtype)

                components.guider.set_state(step=i, num_inference_steps=num_inference_steps, timestep=t)
                guider_state = components.guider.prepare_inputs(guider_inputs)

                for guider_state_batch in guider_state:
                    components.guider.prepare_models(components.transformer)
                    cond_kwargs = {name: getattr(guider_state_batch, name) for name in guider_inputs}
                    context_name = getattr(guider_state_batch, components.guider._identifier_key)
                    with components.transformer.cache_context(context_name):
                        guider_state_batch.noise_pred = components.transformer(
                            hidden_states=latent_model_input,
                            timestep=timestep,
                            return_dict=False,
                            **cond_kwargs,
                            **shared_kwargs,
                        )[0]
                    components.guider.cleanup_models(components.transformer)

                noise_pred = components.guider(guider_state)[0]

                # Scheduler step
                latents = components.scheduler.step(
                    noise_pred,
                    t,
                    latents,
                    generator=generator,
                    return_dict=False,
                )[0]

                if i == len(timesteps) - 1 or (
                    (i + 1) > num_warmup_steps and (i + 1) % components.scheduler.order == 0
                ):
                    progress_bar.update()

        return latents

    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        device = components._execution_device
        transformer_dtype = components.transformer.dtype
        latents = block_state.latents
        pyramid_num_stages = len(block_state.pyramid_num_inference_steps_list)
        patch_size = components.transformer.config.patch_size

        # Guider inputs: only encoder_hidden_states differs between cond/uncond
        guider_inputs = {
            "encoder_hidden_states": (block_state.prompt_embeds, block_state.negative_prompt_embeds),
        }

        # Build shared kwargs from denoiser_input_fields (excludes guider-managed ones)
        transformer_args = set(inspect.signature(components.transformer.forward).parameters.keys())
        shared_kwargs = {}
        for field_name, field_value in block_state.denoiser_input_fields.items():
            if field_name in transformer_args and field_name not in guider_inputs:
                shared_kwargs[field_name] = field_value

        # Add loop-internal history latents with dtype casting
        shared_kwargs["latents_history_short"] = block_state.latents_history_short.to(transformer_dtype)
        shared_kwargs["latents_history_mid"] = block_state.latents_history_mid.to(transformer_dtype)
        shared_kwargs["latents_history_long"] = block_state.latents_history_long.to(transformer_dtype)
        shared_kwargs["attention_kwargs"] = block_state.attention_kwargs

        # Save original zero_init_steps if the guider supports it (e.g. ClassifierFreeZeroStarGuidance).
        # Helios only applies zero init in pyramid stage 0 (lowest resolution), so we disable it
        # for subsequent stages by temporarily setting zero_init_steps=0.
        orig_zero_init_steps = getattr(components.guider, "zero_init_steps", None)

        try:
            for i_s in range(pyramid_num_stages):
                # Disable zero init for stages > 0 (only stage 0 should have zero init)
                if orig_zero_init_steps is not None and i_s > 0:
                    components.guider.zero_init_steps = 0

                # a. Compute mu from current resolution (before upsample, matching standard pipeline)
                image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // (
                    patch_size[0] * patch_size[1] * patch_size[2]
                )
                mu = calculate_shift(
                    image_seq_len,
                    components.scheduler.config.get("base_image_seq_len", 256),
                    components.scheduler.config.get("max_image_seq_len", 4096),
                    components.scheduler.config.get("base_shift", 0.5),
                    components.scheduler.config.get("max_shift", 1.15),
                )

                # b. Set scheduler timesteps for this stage
                num_inference_steps = block_state.pyramid_num_inference_steps_list[i_s]
                components.scheduler.set_timesteps(
                    num_inference_steps,
                    i_s,
                    device=device,
                    mu=mu,
                )
                timesteps = components.scheduler.timesteps

                # c. Upsample + block noise correction for stages > 0
                if i_s > 0:
                    latents = self._upsample_with_block_noise(
                        components, latents, i_s, patch_size, device, transformer_dtype, block_state.generator
                    )

                # d. Timestep denoising loop
                latents = self._denoise_stage(
                    components,
                    latents,
                    timesteps,
                    num_inference_steps,
                    guider_inputs,
                    shared_kwargs,
                    block_state.generator,
                )
        finally:
            # Always restore the guider, even if a stage raised, so the shared component is not left
            # with zero init disabled for later calls.
            if orig_zero_init_steps is not None:
                components.guider.zero_init_steps = orig_zero_init_steps

        block_state.latents = latents
        return components, block_state
| # ======================================== | |
| # Post-Denoise Update | |
| # ======================================== | |
class HeliosChunkUpdateStep(ModularPipelineBlocks):
    """Record a finished chunk: collect its latents, seed ``image_latents`` if needed, extend the history."""

    model_name = "helios"

    # NOTE(review): no decorators are visible on these accessors in this view — confirm whether the base class
    # expects them to be properties.
    def description(self) -> str:
        return (
            "Post-denoising update step: appends the denoised latents to the chunk list, "
            "captures image_latents from the first chunk if needed, and extends history_latents."
        )

    def expected_components(self) -> list[ComponentSpec]:
        return []

    def inputs(self) -> list[InputParam]:
        latents_param = InputParam("latents", type_hint=torch.Tensor)
        history_param = InputParam("history_latents", type_hint=torch.Tensor)
        keep_first_param = InputParam("keep_first_frame", default=True, type_hint=bool)
        return [latents_param, history_param, keep_first_param]

    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        # Collect the denoised latents of this chunk.
        block_state.latent_chunks.append(block_state.latents)

        # On the first chunk, keep its first frame as the prefix if none has been set yet.
        if block_state.keep_first_frame and k == 0 and block_state.image_latents is None:
            block_state.image_latents = block_state.latents[:, :, 0:1, :, :]

        # Append the chunk to the running history along the frame axis (dim=2).
        extended_history = torch.cat([block_state.history_latents, block_state.latents], dim=2)
        block_state.history_latents = extended_history
        return components, block_state
| # ======================================== | |
| # Chunk Loop Wrapper | |
| # ======================================== | |
class HeliosChunkLoopWrapper(LoopSequentialPipelineBlocks):
    """Outer chunk loop that iterates over temporal chunks.

    History indices, scheduler params, and history state are prepared by HeliosPrepareHistoryStep and
    HeliosSetTimestepsStep before this block runs. Sub-blocks handle per-chunk preparation, denoising, and history
    updates.

    Before looping, ``latent_chunks`` is reset to an empty list and ``image_latents`` is set to ``None`` if the block
    state does not already carry it. ``loop_step`` is then called once per chunk with the chunk index ``k``.
    """

    model_name = "helios"

    # NOTE(review): no decorators are visible on these accessors in this view — confirm whether the base class
    # expects them to be properties.
    def description(self) -> str:
        return (
            "Pipeline block that iterates over temporal chunks for progressive video generation. "
            "At each chunk iteration, it runs sub-blocks for preparation, denoising, and history updates."
        )

    def loop_inputs(self) -> list[InputParam]:
        return [
            InputParam("num_latent_chunk", required=True, type_hint=int),
        ]

    def loop_intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors"),
        ]

    def __call__(
        self, components: HeliosModularPipeline, state: PipelineState
    ) -> tuple[HeliosModularPipeline, PipelineState]:
        """Run every chunk through the sub-blocks and return ``(components, state)``."""
        block_state = self.get_block_state(state)

        # Fresh accumulator for this run; sub-blocks append one tensor per chunk.
        block_state.latent_chunks = []
        # Sub-blocks read image_latents unconditionally, so make sure the attribute exists.
        if not hasattr(block_state, "image_latents"):
            block_state.image_latents = None

        for k in range(block_state.num_latent_chunk):
            components, block_state = self.loop_step(components, block_state, k=k)

        self.set_block_state(state, block_state)
        return components, state
| # ======================================== | |
| # Composed Chunk Denoise Steps | |
| # ======================================== | |
class HeliosChunkDenoiseStep(HeliosChunkLoopWrapper):
    """Text-to-video chunk loop.

    Per chunk, the sub-blocks run in this order: history slice, noise generation, scheduler reset, inner denoising,
    chunk/history update.
    """

    block_classes = [
        HeliosChunkHistorySliceStep,
        HeliosChunkNoiseGenStep,
        HeliosChunkSchedulerResetStep,
        HeliosChunkDenoiseInner,
        HeliosChunkUpdateStep,
    ]
    # One name per entry of block_classes, in the same order.
    block_names = [
        "history_slice",
        "noise_gen",
        "scheduler_reset",
        "denoise_inner",
        "update_chunk",
    ]

    # NOTE(review): no decorator is visible on this accessor in this view — confirm whether the base class
    # expects it to be a property.
    def description(self) -> str:
        return (
            "T2V chunk denoise step that iterates over temporal chunks.\n"
            "At each chunk: history_slice -> noise_gen -> scheduler_reset "
            "-> denoise_inner -> update_chunk."
        )
class HeliosI2VChunkDenoiseStep(HeliosChunkLoopWrapper):
    """Image-to-video chunk loop.

    Same sequence as the T2V loop, except the history slice is the I2V variant: history slice (I2V), noise
    generation, scheduler reset, inner denoising, chunk/history update.
    """

    block_classes = [
        HeliosI2VChunkHistorySliceStep,
        HeliosChunkNoiseGenStep,
        HeliosChunkSchedulerResetStep,
        HeliosChunkDenoiseInner,
        HeliosChunkUpdateStep,
    ]
    # One name per entry of block_classes, in the same order.
    block_names = [
        "history_slice",
        "noise_gen",
        "scheduler_reset",
        "denoise_inner",
        "update_chunk",
    ]

    # NOTE(review): no decorator is visible on this accessor in this view — confirm whether the base class
    # expects it to be a property.
    def description(self) -> str:
        return (
            "I2V chunk denoise step that iterates over temporal chunks.\n"
            "At each chunk: history_slice (I2V) -> noise_gen -> scheduler_reset "
            "-> denoise_inner -> update_chunk."
        )
| class HeliosPyramidDistilledChunkDenoiseInner(ModularPipelineBlocks): | |
| """Nested pyramid stage loop with DMD denoising for distilled checkpoints. | |
| Same progressive multi-resolution strategy as HeliosPyramidChunkDenoiseInner, but: | |
| - Guidance is disabled (guidance_scale=1.0, no unconditional pass) | |
| - Supports is_amplify_first_chunk (doubles first chunk's timesteps via scheduler) | |
| - Tracks start_point_list and passes DMD-specific args to scheduler.step() | |
| """ | |
| model_name = "helios-pyramid" | |
| def description(self) -> str: | |
| return ( | |
| "Distilled pyramid denoising inner block for DMD checkpoints. Loops over pyramid stages " | |
| "from smallest to full resolution with guidance disabled and DMD scheduler support." | |
| ) | |
| def expected_components(self) -> list[ComponentSpec]: | |
| return [ | |
| ComponentSpec("transformer", HeliosTransformer3DModel), | |
| ComponentSpec("scheduler", HeliosScheduler), | |
| ComponentSpec( | |
| "guider", | |
| ClassifierFreeGuidance, | |
| config=FrozenDict({"guidance_scale": 1.0}), | |
| default_creation_method="from_config", | |
| ), | |
| ] | |
| def inputs(self) -> list[InputParam]: | |
| return [ | |
| InputParam.template("latents"), | |
| InputParam("prompt_embeds", type_hint=torch.Tensor), | |
| InputParam("negative_prompt_embeds", type_hint=torch.Tensor), | |
| InputParam.template("denoiser_input_fields"), | |
| InputParam( | |
| "pyramid_num_inference_steps_list", | |
| default=[2, 2, 2], | |
| type_hint=list, | |
| description="Number of denoising steps per pyramid stage.", | |
| ), | |
| InputParam( | |
| "is_amplify_first_chunk", | |
| default=True, | |
| type_hint=bool, | |
| description="Whether to double the first chunk's timesteps via the scheduler for amplified generation.", | |
| ), | |
| InputParam.template("attention_kwargs"), | |
| InputParam.template("generator"), | |
| ] | |
def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
    """Denoise the latents of chunk `k` through every pyramid stage.

    One stage is run per entry of `block_state.pyramid_num_inference_steps_list`. Each stage re-configures the
    scheduler with a `mu` derived from the current latent resolution. From the second stage on, the latents are
    first upsampled 2x (nearest neighbour) and blended with block noise. Every timestep runs the transformer once
    per guider batch, combines the predictions through the guider, and advances the latents with a scheduler step
    that receives the DMD arguments.

    Args:
        components: Pipeline components; `transformer`, `scheduler` and `guider` are used here.
        block_state: State of the current chunk. Reads `latents`, `prompt_embeds`, `negative_prompt_embeds`,
            the three `latents_history_*` tensors, `denoiser_input_fields`, `attention_kwargs`, `generator`,
            `is_amplify_first_chunk` and `pyramid_num_inference_steps_list`.
        k: Index of the chunk in the outer loop; it is only compared against 0 to detect the first chunk.

    Returns:
        The `(components, block_state)` pair, with `block_state.latents` set to the denoised latents.
    """
    # NOTE(review): nothing in this method disables gradient tracking -- confirm the caller (or a decorator
    # not visible here) wraps it in `torch.no_grad()`.
    execution_device = components._execution_device
    model_dtype = components.transformer.dtype
    latents = block_state.latents
    stage_step_counts = block_state.pyramid_num_inference_steps_list
    amplify_first_chunk = block_state.is_amplify_first_chunk and (k == 0)

    # The DMD scheduler step needs the latents each stage started from; stage 0 starts from the input latents.
    start_point_list = [latents]

    # Inputs the guider splits into conditional / unconditional variants.
    guider_inputs = {
        "encoder_hidden_states": (block_state.prompt_embeds, block_state.negative_prompt_embeds),
    }

    # Forward every denoiser input field the transformer accepts, except those the guider already provides.
    accepted_params = set(inspect.signature(components.transformer.forward).parameters)
    shared_kwargs = {
        name: value
        for name, value in block_state.denoiser_input_fields.items()
        if name in accepted_params and name not in guider_inputs
    }
    # History latents come from the chunk loop and are cast to the transformer dtype.
    shared_kwargs["latents_history_short"] = block_state.latents_history_short.to(model_dtype)
    shared_kwargs["latents_history_mid"] = block_state.latents_history_mid.to(model_dtype)
    shared_kwargs["latents_history_long"] = block_state.latents_history_long.to(model_dtype)
    shared_kwargs["attention_kwargs"] = block_state.attention_kwargs

    def _upsample_nearest_2x(video_latents):
        # Fold frames into the batch axis so the 2D interpolation acts on each frame independently.
        batch, channels, frames, height, width = video_latents.shape
        doubled_h = height * 2
        doubled_w = width * 2
        per_frame = video_latents.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)
        per_frame = F.interpolate(per_frame, size=(doubled_h, doubled_w), mode="nearest")
        return per_frame.reshape(batch, frames, channels, doubled_h, doubled_w).permute(0, 2, 1, 3, 4)

    for stage_idx, num_inference_steps in enumerate(stage_step_counts):
        # --- Stage setup ---
        patch_size = components.transformer.config.patch_size

        # mu is computed from the resolution *before* this stage's upsampling.
        frames, height, width = latents.shape[-3], latents.shape[-2], latents.shape[-1]
        image_seq_len = (width * height * frames) // (patch_size[0] * patch_size[1] * patch_size[2])
        scheduler_config = components.scheduler.config
        mu = calculate_shift(
            image_seq_len,
            scheduler_config.get("base_image_seq_len", 256),
            scheduler_config.get("max_image_seq_len", 4096),
            scheduler_config.get("base_shift", 0.5),
            scheduler_config.get("max_shift", 1.15),
        )

        # Configure the scheduler for this stage.
        components.scheduler.set_timesteps(
            num_inference_steps,
            stage_idx,
            device=execution_device,
            mu=mu,
            is_amplify_first_chunk=amplify_first_chunk,
        )
        timesteps = components.scheduler.timesteps

        # Later stages start from the upsampled previous result, blended with block noise.
        if stage_idx > 0:
            latents = _upsample_nearest_2x(latents)

            ori_sigma = 1 - components.scheduler.ori_start_sigmas[stage_idx]
            gamma = components.scheduler.config.gamma
            alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
            beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)

            batch, channels, frames, height, width = latents.shape
            block_noise = sample_block_noise(
                batch,
                channels,
                frames,
                height,
                width,
                gamma,
                patch_size,
                device=execution_device,
                generator=block_state.generator,
            ).to(dtype=model_dtype)
            latents = alpha * latents + beta * block_noise
            start_point_list.append(latents)

        # --- Timestep denoising loop ---
        num_warmup_steps = len(timesteps) - num_inference_steps * components.scheduler.order
        with tqdm(total=num_inference_steps) as progress_bar:
            for step_idx, t in enumerate(timesteps):
                timestep = t.expand(latents.shape[0]).to(torch.int64)
                latent_model_input = latents.to(model_dtype)

                components.guider.set_state(step=step_idx, num_inference_steps=num_inference_steps, timestep=t)
                guider_state = components.guider.prepare_inputs(guider_inputs)

                # One transformer pass per guider batch; each prediction is stored on its batch object.
                for guider_batch in guider_state:
                    components.guider.prepare_models(components.transformer)
                    cond_kwargs = {field: getattr(guider_batch, field) for field in guider_inputs}
                    context_name = getattr(guider_batch, components.guider._identifier_key)
                    with components.transformer.cache_context(context_name):
                        prediction = components.transformer(
                            hidden_states=latent_model_input,
                            timestep=timestep,
                            return_dict=False,
                            **cond_kwargs,
                            **shared_kwargs,
                        )
                        guider_batch.noise_pred = prediction[0]
                    components.guider.cleanup_models(components.transformer)

                noise_pred = components.guider(guider_state)[0]

                # Scheduler step, passing this stage's start point and the scheduler's current sigmas/timesteps.
                step_output = components.scheduler.step(
                    noise_pred,
                    t,
                    latents,
                    generator=block_state.generator,
                    return_dict=False,
                    cur_sampling_step=step_idx,
                    dmd_noisy_tensor=start_point_list[stage_idx],
                    dmd_sigmas=components.scheduler.sigmas,
                    dmd_timesteps=components.scheduler.timesteps,
                    all_timesteps=timesteps,
                )
                latents = step_output[0]

                steps_done = step_idx + 1
                if steps_done == len(timesteps) or (
                    steps_done > num_warmup_steps and steps_done % components.scheduler.order == 0
                ):
                    progress_bar.update()

    block_state.latents = latents
    return components, block_state
class HeliosPyramidChunkDenoiseStep(HeliosChunkLoopWrapper):
    """Chunk loop for text-to-video pyramid denoising.

    Runs four sub-blocks per chunk, in this order: `history_slice`, `noise_gen`, `denoise_inner`, `update_chunk`.
    """

    # Sub-block classes and their names; the two lists are index-aligned.
    block_classes = [
        HeliosChunkHistorySliceStep,
        HeliosPyramidChunkNoiseGenStep,
        HeliosPyramidChunkDenoiseInner,
        HeliosChunkUpdateStep,
    ]
    block_names = ["history_slice", "noise_gen", "denoise_inner", "update_chunk"]

    # NOTE(review): upstream diffusers exposes `description` as a property on the block base class -- confirm
    # whether a `@property` decorator belongs here.
    def description(self) -> str:
        """Return the human-readable summary of this block."""
        headline = "T2V pyramid chunk denoise step that iterates over temporal chunks.\n"
        flow = (
            "At each chunk: history_slice -> noise_gen (pyramid) -> denoise_inner (pyramid stages) -> update_chunk.\n"
        )
        resolution_note = "Denoising starts at the smallest resolution and progressively upsamples."
        return headline + flow + resolution_note
class HeliosPyramidI2VChunkDenoiseStep(HeliosChunkLoopWrapper):
    """Chunk loop for image-to-video pyramid denoising.

    Runs four sub-blocks per chunk, in this order: `history_slice` (the I2V variant), `noise_gen`,
    `denoise_inner`, `update_chunk`.
    """

    # Sub-block classes and their names; the two lists are index-aligned.
    block_classes = [
        HeliosI2VChunkHistorySliceStep,
        HeliosPyramidChunkNoiseGenStep,
        HeliosPyramidChunkDenoiseInner,
        HeliosChunkUpdateStep,
    ]
    block_names = ["history_slice", "noise_gen", "denoise_inner", "update_chunk"]

    # NOTE(review): upstream diffusers exposes `description` as a property on the block base class -- confirm
    # whether a `@property` decorator belongs here.
    def description(self) -> str:
        """Return the human-readable summary of this block."""
        headline = "I2V pyramid chunk denoise step that iterates over temporal chunks.\n"
        flow = (
            "At each chunk: history_slice (I2V) -> noise_gen (pyramid) -> denoise_inner (pyramid stages) -> update_chunk.\n"
        )
        resolution_note = "Denoising starts at the smallest resolution and progressively upsamples."
        return headline + flow + resolution_note
class HeliosPyramidDistilledChunkDenoiseStep(HeliosChunkLoopWrapper):
    """Chunk loop for text-to-video pyramid denoising with the distilled inner denoiser.

    Runs four sub-blocks per chunk, in this order: `history_slice`, `noise_gen`, `denoise_inner` (the distilled
    variant), `update_chunk`.
    """

    # Sub-block classes and their names; the two lists are index-aligned.
    block_classes = [
        HeliosChunkHistorySliceStep,
        HeliosPyramidChunkNoiseGenStep,
        HeliosPyramidDistilledChunkDenoiseInner,
        HeliosChunkUpdateStep,
    ]
    block_names = ["history_slice", "noise_gen", "denoise_inner", "update_chunk"]

    # NOTE(review): upstream diffusers exposes `description` as a property on the block base class -- confirm
    # whether a `@property` decorator belongs here.
    def description(self) -> str:
        """Return the human-readable summary of this block."""
        headline = "T2V distilled pyramid chunk denoise step with DMD scheduler.\n"
        flow = "At each chunk: history_slice -> noise_gen (pyramid) -> denoise_inner (distilled/DMD) -> update_chunk."
        return headline + flow
class HeliosPyramidDistilledI2VChunkDenoiseStep(HeliosChunkLoopWrapper):
    """Chunk loop for image-to-video pyramid denoising with the distilled inner denoiser.

    Runs four sub-blocks per chunk, in this order: `history_slice` (the I2V variant), `noise_gen`,
    `denoise_inner` (the distilled variant), `update_chunk`.
    """

    # Sub-block classes and their names; the two lists are index-aligned.
    block_classes = [
        HeliosI2VChunkHistorySliceStep,
        HeliosPyramidChunkNoiseGenStep,
        HeliosPyramidDistilledChunkDenoiseInner,
        HeliosChunkUpdateStep,
    ]
    block_names = ["history_slice", "noise_gen", "denoise_inner", "update_chunk"]

    # NOTE(review): upstream diffusers exposes `description` as a property on the block base class -- confirm
    # whether a `@property` decorator belongs here.
    def description(self) -> str:
        """Return the human-readable summary of this block."""
        headline = "I2V distilled pyramid chunk denoise step with DMD scheduler.\n"
        flow = (
            "At each chunk: history_slice (I2V) -> noise_gen (pyramid) -> denoise_inner (distilled/DMD) -> update_chunk."
        )
        return headline + flow