# Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import math

import torch
import torch.nn.functional as F
from tqdm.auto import tqdm

from ...configuration_utils import FrozenDict
from ...guiders import ClassifierFreeGuidance, ClassifierFreeZeroStarGuidance
from ...models import HeliosTransformer3DModel
from ...schedulers import HeliosScheduler
from ...utils import logging
from ...utils.torch_utils import randn_tensor
from ..modular_pipeline import (
    BlockState,
    LoopSequentialPipelineBlocks,
    ModularPipelineBlocks,
    PipelineState,
)
from ..modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
from .before_denoise import calculate_shift
from .modular_pipeline import HeliosModularPipeline


logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


def sample_block_noise(
    batch_size,
    channel,
    num_frames,
    height,
    width,
    gamma,
    patch_size=(1, 2, 2),
    device=None,
    generator=None,
):
    """Generate spatially-correlated block noise for pyramid upsampling correction.

    Uses a multivariate normal distribution with covariance based on `gamma` to produce noise with block structure,
    matching the upsampling artifacts that need correction.

    Args:
        batch_size, channel, num_frames, height, width: Dimensions of the returned noise tensor.
        gamma: Off-diagonal covariance strength; the covariance is `(1 + gamma) * I - gamma * ones`.
        patch_size: `(_, ph, pw)`; only the two spatial entries are used. Each `ph x pw` spatial block is drawn
            jointly. `height` and `width` are expected to be divisible by `ph` and `pw` respectively (the final
            reshape fails otherwise).
        device: Device of the returned noise.
        generator: A `torch.Generator`, a list of generators (only the first one is used), or `None`.

    Returns:
        A tensor of shape `(batch_size, channel, num_frames, height, width)` on `device`.
    """
    # NOTE: A generator must be provided to ensure correct and reproducible results.
    # Creating a default generator here is a fallback only — without a fixed seed,
    # the output will be non-deterministic and may produce incorrect results in CP context.
    if generator is None:
        generator = torch.Generator(device=device)
    elif isinstance(generator, list):
        generator = generator[0]

    _, ph, pw = patch_size
    block_size = ph * pw

    cov = (
        torch.eye(block_size, device=device) * (1 + gamma)
        - torch.ones(block_size, block_size, device=device) * gamma
    )
    # Small diagonal jitter before the factorization.
    cov += torch.eye(block_size, device=device) * 1e-8
    cov = cov.float()  # Upcast to fp32 for numerical stability — cholesky is unreliable in fp16/bf16.

    L = torch.linalg.cholesky(cov)
    block_number = batch_size * channel * num_frames * (height // ph) * (width // pw)
    # Sample on the generator's own device, then move to the requested device.
    z = torch.randn(block_number, block_size, device=generator.device, generator=generator).to(device)
    noise = z @ L.T

    # (blocks, ph*pw) -> (B, C, T, H/ph, W/pw, ph, pw) -> interleave patch rows/cols back into (B, C, T, H, W).
    noise = noise.view(batch_size, channel, num_frames, height // ph, width // pw, ph, pw)
    noise = noise.permute(0, 1, 2, 3, 5, 4, 6).reshape(batch_size, channel, num_frames, height, width)
    return noise


# ========================================
# Shared private helpers
# ========================================


def _build_shared_denoiser_kwargs(components, block_state, guider_inputs, transformer_dtype):
    """Collect the transformer kwargs that are identical for every guider batch.

    Keeps the `denoiser_input_fields` entries that the transformer's `forward` accepts and that are not managed by
    the guider, then adds the short/mid/long history latents (cast to `transformer_dtype`) and `attention_kwargs`.
    """
    transformer_args = set(inspect.signature(components.transformer.forward).parameters)
    shared_kwargs = {
        field_name: field_value
        for field_name, field_value in block_state.denoiser_input_fields.items()
        if field_name in transformer_args and field_name not in guider_inputs
    }

    # Add loop-internal history latents with dtype casting
    shared_kwargs["latents_history_short"] = block_state.latents_history_short.to(transformer_dtype)
    shared_kwargs["latents_history_mid"] = block_state.latents_history_mid.to(transformer_dtype)
    shared_kwargs["latents_history_long"] = block_state.latents_history_long.to(transformer_dtype)
    shared_kwargs["attention_kwargs"] = block_state.attention_kwargs
    return shared_kwargs


def _predict_guided_noise(
    components, guider_inputs, shared_kwargs, latents, t, step, num_inference_steps, transformer_dtype
):
    """Run the transformer once per guider batch and return the guider's combined noise prediction."""
    timestep = t.expand(latents.shape[0]).to(torch.int64)
    latent_model_input = latents.to(transformer_dtype)

    components.guider.set_state(step=step, num_inference_steps=num_inference_steps, timestep=t)
    guider_state = components.guider.prepare_inputs(guider_inputs)

    for guider_state_batch in guider_state:
        components.guider.prepare_models(components.transformer)
        # `input_name` (not `k`) so the chunk index used by the callers is never shadowed.
        cond_kwargs = {input_name: getattr(guider_state_batch, input_name) for input_name in guider_inputs}

        context_name = getattr(guider_state_batch, components.guider._identifier_key)
        with components.transformer.cache_context(context_name):
            guider_state_batch.noise_pred = components.transformer(
                hidden_states=latent_model_input,
                timestep=timestep,
                return_dict=False,
                **cond_kwargs,
                **shared_kwargs,
            )[0]
        components.guider.cleanup_models(components.transformer)

    return components.guider(guider_state)[0]


def _should_update_progress(step, num_timesteps, num_warmup_steps, scheduler_order):
    """Return True on the last step, or on every `scheduler_order`-th step once warmup is over."""
    if step == num_timesteps - 1:
        return True
    return (step + 1) > num_warmup_steps and (step + 1) % scheduler_order == 0


def _compute_stage_mu(components, latents, patch_size):
    """Compute the shift parameter `mu` from the patchified sequence length of `latents`."""
    image_seq_len = (latents.shape[-1] * latents.shape[-2] * latents.shape[-3]) // (
        patch_size[0] * patch_size[1] * patch_size[2]
    )
    scheduler_config = components.scheduler.config
    return calculate_shift(
        image_seq_len,
        scheduler_config.get("base_image_seq_len", 256),
        scheduler_config.get("max_image_seq_len", 4096),
        scheduler_config.get("base_shift", 0.5),
        scheduler_config.get("max_shift", 1.15),
    )


def _upsample_with_block_noise(components, latents, stage_index, patch_size, device, generator, transformer_dtype):
    """Upsample `latents` 2x spatially (nearest) and blend in block noise for pyramid stage `stage_index`.

    Reads `components.scheduler.ori_start_sigmas[stage_index]`, so the scheduler must already have been set up for
    this stage by the caller.
    """
    batch_size, num_channels_latents, num_frames, current_h, current_w = latents.shape
    new_h = current_h * 2
    new_w = current_w * 2

    # Fold frames into the batch dim so the 2D interpolation is applied per frame.
    latents = latents.permute(0, 2, 1, 3, 4).reshape(
        batch_size * num_frames, num_channels_latents, current_h, current_w
    )
    latents = F.interpolate(latents, size=(new_h, new_w), mode="nearest")
    latents = latents.reshape(batch_size, num_frames, num_channels_latents, new_h, new_w).permute(0, 2, 1, 3, 4)

    # Block noise correction
    ori_sigma = 1 - components.scheduler.ori_start_sigmas[stage_index]
    gamma = components.scheduler.config.gamma
    alpha = 1 / (math.sqrt(1 + (1 / gamma)) * (1 - ori_sigma) + ori_sigma)
    beta = alpha * (1 - ori_sigma) / math.sqrt(gamma)

    noise = sample_block_noise(
        batch_size,
        num_channels_latents,
        num_frames,
        new_h,
        new_w,
        gamma,
        patch_size,
        device=device,
        generator=generator,
    )
    noise = noise.to(dtype=transformer_dtype)
    return alpha * latents + beta * noise


# ========================================
# Chunk Loop Leaf Blocks
# ========================================


class HeliosChunkHistorySliceStep(ModularPipelineBlocks):
    """Slices history latents into short/mid/long for a T2V chunk.

    At k==0 with no image_latents, creates a zero prefix. Otherwise uses image_latents (either provided or captured
    from first chunk by HeliosChunkUpdateStep).
    """

    model_name = "helios"

    @property
    def description(self) -> str:
        return (
            "T2V history slice: splits history into long/mid/short. At k==0 with no image_latents, "
            "creates a zero prefix; otherwise uses image_latents as prefix for short history."
        )

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam(
                "keep_first_frame",
                default=True,
                type_hint=bool,
                description="Whether to keep the first frame as a prefix in history.",
            ),
            InputParam(
                "history_sizes",
                required=True,
                type_hint=list,
                description="Sizes of long/mid/short history buffers for temporal context.",
            ),
            InputParam(
                "history_latents",
                required=True,
                type_hint=torch.Tensor,
                description="Accumulated history latents from previous chunks.",
            ),
            InputParam("latent_shape", required=True, type_hint=tuple),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return []

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        """Write `latents_history_short/mid/long` onto `block_state` for chunk index `k`."""
        keep_first_frame = block_state.keep_first_frame
        history_sizes = block_state.history_sizes
        # `image_latents` is loop-internal state: the chunk loop wrapper initialises it to None when absent.
        image_latents = block_state.image_latents
        device = components._execution_device

        batch_size, num_channels_latents, _, h_latent, w_latent = block_state.latent_shape

        # Only the most recent sum(history_sizes) frames are split into the three buffers.
        if keep_first_frame:
            latents_history_long, latents_history_mid, latents_history_1x = block_state.history_latents[
                :, :, -sum(history_sizes) :
            ].split(history_sizes, dim=2)
            if image_latents is None and k == 0:
                latents_prefix = torch.zeros(
                    batch_size,
                    num_channels_latents,
                    1,
                    h_latent,
                    w_latent,
                    device=device,
                    dtype=torch.float32,
                )
            else:
                latents_prefix = image_latents
            latents_history_short = torch.cat([latents_prefix, latents_history_1x], dim=2)
        else:
            latents_history_long, latents_history_mid, latents_history_short = block_state.history_latents[
                :, :, -sum(history_sizes) :
            ].split(history_sizes, dim=2)

        block_state.latents_history_short = latents_history_short
        block_state.latents_history_mid = latents_history_mid
        block_state.latents_history_long = latents_history_long

        return components, block_state


class HeliosI2VChunkHistorySliceStep(ModularPipelineBlocks):
    """Slices history latents into short/mid/long for an I2V chunk.

    Always uses image_latents as prefix (assumes history pre-seeded with fake_image_latents).
    """

    model_name = "helios"

    @property
    def description(self) -> str:
        return (
            "I2V history slice: splits pre-seeded history into long/mid/short, "
            "always using image_latents as prefix for short history."
        )

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam(
                "keep_first_frame",
                default=True,
                type_hint=bool,
                description="Whether to keep the first frame as a prefix in history.",
            ),
            InputParam(
                "history_sizes",
                required=True,
                type_hint=list,
                description="Sizes of long/mid/short history buffers for temporal context.",
            ),
            InputParam(
                "history_latents",
                required=True,
                type_hint=torch.Tensor,
                description="Accumulated history latents from previous chunks.",
            ),
            InputParam(
                "image_latents",
                required=True,
                type_hint=torch.Tensor,
                description="First-frame latents used as prefix for short history.",
            ),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return []

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        """Write `latents_history_short/mid/long` onto `block_state`; `k` is not used by this block."""
        keep_first_frame = block_state.keep_first_frame
        history_sizes = block_state.history_sizes
        image_latents = block_state.image_latents

        if keep_first_frame:
            latents_history_long, latents_history_mid, latents_history_1x = block_state.history_latents[
                :, :, -sum(history_sizes) :
            ].split(history_sizes, dim=2)
            latents_history_short = torch.cat([image_latents, latents_history_1x], dim=2)
        else:
            latents_history_long, latents_history_mid, latents_history_short = block_state.history_latents[
                :, :, -sum(history_sizes) :
            ].split(history_sizes, dim=2)

        block_state.latents_history_short = latents_history_short
        block_state.latents_history_mid = latents_history_mid
        block_state.latents_history_long = latents_history_long

        return components, block_state


class HeliosChunkNoiseGenStep(ModularPipelineBlocks):
    """Generates noise latents for a chunk using randn_tensor."""

    model_name = "helios"

    @property
    def description(self) -> str:
        return "Generates random noise latents at full resolution for a single chunk."

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("latent_shape", required=True, type_hint=tuple),
            InputParam.template("generator"),
        ]

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        """Set `block_state.latents` to fresh float32 noise of shape `latent_shape`."""
        device = components._execution_device
        block_state.latents = randn_tensor(
            block_state.latent_shape, generator=block_state.generator, device=device, dtype=torch.float32
        )
        return components, block_state


class HeliosPyramidChunkNoiseGenStep(ModularPipelineBlocks):
    """Generates noise latents and downsamples to smallest pyramid level."""

    model_name = "helios-pyramid"

    @property
    def description(self) -> str:
        return (
            "Generates random noise at full resolution, then downsamples to the smallest "
            "pyramid level via bilinear interpolation."
        )

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("latent_shape", required=True, type_hint=tuple),
            InputParam(
                "pyramid_num_inference_steps_list",
                default=[10, 10, 10],
                type_hint=list,
                description="Number of denoising steps per pyramid stage.",
            ),
            InputParam.template("generator"),
        ]

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        """Set `block_state.latents` to noise halved spatially once per pyramid stage after the first."""
        device = components._execution_device
        batch_size, num_channels_latents, num_latent_frames, h_latent, w_latent = block_state.latent_shape

        latents = randn_tensor(
            block_state.latent_shape, generator=block_state.generator, device=device, dtype=torch.float32
        )

        # Downsample to smallest pyramid level
        h, w = h_latent, w_latent
        # Fold frames into the batch dim so the 2D interpolation is applied per frame.
        latents = latents.permute(0, 2, 1, 3, 4).reshape(batch_size * num_latent_frames, num_channels_latents, h, w)
        for _ in range(len(block_state.pyramid_num_inference_steps_list) - 1):
            h //= 2
            w //= 2
            latents = F.interpolate(latents, size=(h, w), mode="bilinear") * 2
        block_state.latents = latents.reshape(batch_size, num_latent_frames, num_channels_latents, h, w).permute(
            0, 2, 1, 3, 4
        )

        return components, block_state


class HeliosChunkSchedulerResetStep(ModularPipelineBlocks):
    """Resets the scheduler with timesteps for a single chunk."""

    model_name = "helios"

    @property
    def description(self) -> str:
        return "Resets the scheduler with the correct timesteps and shift parameter (mu) for this chunk."

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("scheduler", HeliosScheduler),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("mu", required=True, type_hint=float),
            InputParam.template("sigmas", required=True),
            InputParam.template("num_inference_steps"),
        ]

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        """Call `scheduler.set_timesteps` and expose the resulting timesteps on `block_state.timesteps`."""
        device = components._execution_device
        components.scheduler.set_timesteps(
            block_state.num_inference_steps, device=device, sigmas=block_state.sigmas, mu=block_state.mu
        )
        block_state.timesteps = components.scheduler.timesteps
        return components, block_state


# ========================================
# Inner Denoising Blocks
# ========================================


class HeliosChunkDenoiseInner(ModularPipelineBlocks):
    """Inner timestep loop for denoising a single chunk, using guider for guidance."""

    model_name = "helios"

    @property
    def description(self) -> str:
        return (
            "Inner denoising loop that iterates over timesteps for a single chunk. "
            "Uses the guider to manage conditional/unconditional forward passes with cache_context, "
            "applies guidance, and runs scheduler step."
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("transformer", HeliosTransformer3DModel),
            ComponentSpec("scheduler", HeliosScheduler),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 5.0}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("latents"),
            InputParam.template("timesteps"),
            InputParam("prompt_embeds", type_hint=torch.Tensor),
            InputParam("negative_prompt_embeds", type_hint=torch.Tensor),
            InputParam.template("denoiser_input_fields"),
            InputParam.template("num_inference_steps"),
            InputParam.template("attention_kwargs"),
            InputParam.template("generator"),
        ]

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        """Denoise `block_state.latents` over `block_state.timesteps` and store the result back on the state."""
        latents = block_state.latents
        timesteps = block_state.timesteps
        num_inference_steps = block_state.num_inference_steps

        transformer_dtype = components.transformer.dtype
        num_warmup_steps = len(timesteps) - num_inference_steps * components.scheduler.order

        # Guider inputs: only encoder_hidden_states differs between cond/uncond
        guider_inputs = {
            "encoder_hidden_states": (block_state.prompt_embeds, block_state.negative_prompt_embeds),
        }

        # Shared kwargs from denoiser_input_fields (excludes guider-managed ones) plus history latents
        shared_kwargs = _build_shared_denoiser_kwargs(components, block_state, guider_inputs, transformer_dtype)

        with tqdm(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                noise_pred = _predict_guided_noise(
                    components, guider_inputs, shared_kwargs, latents, t, i, num_inference_steps, transformer_dtype
                )

                # Scheduler step
                latents = components.scheduler.step(
                    noise_pred,
                    t,
                    latents,
                    generator=block_state.generator,
                    return_dict=False,
                )[0]

                if _should_update_progress(i, len(timesteps), num_warmup_steps, components.scheduler.order):
                    progress_bar.update()

        block_state.latents = latents
        return components, block_state


class HeliosPyramidChunkDenoiseInner(ModularPipelineBlocks):
    """Nested pyramid stage loop with inner timestep denoising.

    For each pyramid stage (small -> full resolution):
    1. Compute mu from current resolution, set scheduler timesteps
    2. Upsample latents + block noise correction (stages > 0)
    3. Run timestep denoising loop (same logic as HeliosChunkDenoiseInner)
    """

    model_name = "helios-pyramid"

    @property
    def description(self) -> str:
        return (
            "Pyramid denoising inner block: loops over pyramid stages from smallest to full resolution. "
            "Each stage upsamples latents (with block noise correction), recomputes scheduler parameters, "
            "and runs the timestep denoising loop."
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("transformer", HeliosTransformer3DModel),
            ComponentSpec("scheduler", HeliosScheduler),
            ComponentSpec(
                "guider",
                ClassifierFreeZeroStarGuidance,
                config=FrozenDict({"guidance_scale": 5.0, "zero_init_steps": 2}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("latents"),
            InputParam("prompt_embeds", type_hint=torch.Tensor),
            InputParam("negative_prompt_embeds", type_hint=torch.Tensor),
            InputParam.template("denoiser_input_fields"),
            InputParam(
                "pyramid_num_inference_steps_list",
                default=[10, 10, 10],
                type_hint=list,
                description="Number of denoising steps per pyramid stage.",
            ),
            InputParam.template("attention_kwargs"),
            InputParam.template("generator"),
        ]

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        """Denoise `block_state.latents` through every pyramid stage and store the result back on the state."""
        device = components._execution_device
        transformer_dtype = components.transformer.dtype
        latents = block_state.latents
        pyramid_num_stages = len(block_state.pyramid_num_inference_steps_list)
        patch_size = components.transformer.config.patch_size

        # Guider inputs: only encoder_hidden_states differs between cond/uncond
        guider_inputs = {
            "encoder_hidden_states": (block_state.prompt_embeds, block_state.negative_prompt_embeds),
        }

        # Shared kwargs from denoiser_input_fields (excludes guider-managed ones) plus history latents
        shared_kwargs = _build_shared_denoiser_kwargs(components, block_state, guider_inputs, transformer_dtype)

        # Save original zero_init_steps if the guider supports it (e.g. ClassifierFreeZeroStarGuidance).
        # Helios only applies zero init in pyramid stage 0 (lowest resolution), so we disable it
        # for subsequent stages by temporarily setting zero_init_steps=0.
        orig_zero_init_steps = getattr(components.guider, "zero_init_steps", None)

        # try/finally: the guider is a shared component, so its configuration must be restored even when
        # a stage raises part-way through; otherwise later calls would silently run without zero init.
        try:
            for i_s in range(pyramid_num_stages):
                # --- Stage setup ---
                # Disable zero init for stages > 0 (only stage 0 should have zero init)
                if orig_zero_init_steps is not None and i_s > 0:
                    components.guider.zero_init_steps = 0

                # a. Compute mu from current resolution (before upsample, matching standard pipeline)
                mu = _compute_stage_mu(components, latents, patch_size)

                # b. Set scheduler timesteps for this stage
                num_inference_steps = block_state.pyramid_num_inference_steps_list[i_s]
                components.scheduler.set_timesteps(
                    num_inference_steps,
                    i_s,
                    device=device,
                    mu=mu,
                )
                timesteps = components.scheduler.timesteps

                # c. Upsample + block noise correction for stages > 0
                if i_s > 0:
                    latents = _upsample_with_block_noise(
                        components, latents, i_s, patch_size, device, block_state.generator, transformer_dtype
                    )

                # --- Timestep denoising loop ---
                num_warmup_steps = len(timesteps) - num_inference_steps * components.scheduler.order

                with tqdm(total=num_inference_steps) as progress_bar:
                    for i, t in enumerate(timesteps):
                        noise_pred = _predict_guided_noise(
                            components,
                            guider_inputs,
                            shared_kwargs,
                            latents,
                            t,
                            i,
                            num_inference_steps,
                            transformer_dtype,
                        )

                        # Scheduler step
                        latents = components.scheduler.step(
                            noise_pred,
                            t,
                            latents,
                            generator=block_state.generator,
                            return_dict=False,
                        )[0]

                        if _should_update_progress(i, len(timesteps), num_warmup_steps, components.scheduler.order):
                            progress_bar.update()
        finally:
            # Restore original zero_init_steps
            if orig_zero_init_steps is not None:
                components.guider.zero_init_steps = orig_zero_init_steps

        block_state.latents = latents
        return components, block_state


# ========================================
# Post-Denoise Update
# ========================================


class HeliosChunkUpdateStep(ModularPipelineBlocks):
    """Updates chunk collection and history after denoising a single chunk."""

    model_name = "helios"

    @property
    def description(self) -> str:
        return (
            "Post-denoising update step: appends the denoised latents to the chunk list, "
            "captures image_latents from the first chunk if needed, and extends history_latents."
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return []

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam("latents", type_hint=torch.Tensor),
            InputParam("history_latents", type_hint=torch.Tensor),
            InputParam("keep_first_frame", default=True, type_hint=bool),
        ]

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        """Append this chunk's latents to `latent_chunks` and to `history_latents` (along dim 2)."""
        # e. Collect denoised latents for this chunk
        block_state.latent_chunks.append(block_state.latents)

        # f. Update history
        # On the first chunk, with no image_latents available, keep this chunk's first frame as the prefix
        # that later history slices will use.
        if block_state.keep_first_frame and k == 0 and block_state.image_latents is None:
            block_state.image_latents = block_state.latents[:, :, 0:1, :, :]

        block_state.history_latents = torch.cat([block_state.history_latents, block_state.latents], dim=2)

        return components, block_state


# ========================================
# Chunk Loop Wrapper
# ========================================


class HeliosChunkLoopWrapper(LoopSequentialPipelineBlocks):
    """Outer chunk loop that iterates over temporal chunks.

    History indices, scheduler params, and history state are prepared by HeliosPrepareHistoryStep and
    HeliosSetTimestepsStep before this block runs. Sub-blocks handle per-chunk preparation, denoising, and history
    updates.
    """

    model_name = "helios"

    @property
    def description(self) -> str:
        return (
            "Pipeline block that iterates over temporal chunks for progressive video generation. "
            "At each chunk iteration, it runs sub-blocks for preparation, denoising, and history updates."
        )

    @property
    def loop_inputs(self) -> list[InputParam]:
        return [
            InputParam("num_latent_chunk", required=True, type_hint=int),
        ]

    @property
    def loop_intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam("latent_chunks", type_hint=list, description="List of per-chunk denoised latent tensors"),
        ]

    @torch.no_grad()
    def __call__(
        self, components: HeliosModularPipeline, state: PipelineState
    ) -> tuple[HeliosModularPipeline, PipelineState]:
        """Run `loop_step` once per chunk and return `(components, state)`."""
        block_state = self.get_block_state(state)

        block_state.latent_chunks = []

        # Sub-blocks read `image_latents` unconditionally, so make sure the attribute exists.
        if not hasattr(block_state, "image_latents"):
            block_state.image_latents = None

        for k in range(block_state.num_latent_chunk):
            components, block_state = self.loop_step(components, block_state, k=k)

        self.set_block_state(state, block_state)

        return components, state


# ========================================
# Composed Chunk Denoise Steps
# ========================================


class HeliosChunkDenoiseStep(HeliosChunkLoopWrapper):
    """T2V chunk-based denoising: history slice -> noise gen -> scheduler reset -> denoise -> update."""

    block_classes = [
        HeliosChunkHistorySliceStep,
        HeliosChunkNoiseGenStep,
        HeliosChunkSchedulerResetStep,
        HeliosChunkDenoiseInner,
        HeliosChunkUpdateStep,
    ]
    block_names = ["history_slice", "noise_gen", "scheduler_reset", "denoise_inner", "update_chunk"]

    @property
    def description(self) -> str:
        return (
            "T2V chunk denoise step that iterates over temporal chunks.\n"
            "At each chunk: history_slice -> noise_gen -> scheduler_reset -> denoise_inner -> update_chunk."
        )


class HeliosI2VChunkDenoiseStep(HeliosChunkLoopWrapper):
    """I2V chunk-based denoising: I2V history slice -> noise gen -> scheduler reset -> denoise -> update."""

    block_classes = [
        HeliosI2VChunkHistorySliceStep,
        HeliosChunkNoiseGenStep,
        HeliosChunkSchedulerResetStep,
        HeliosChunkDenoiseInner,
        HeliosChunkUpdateStep,
    ]
    block_names = ["history_slice", "noise_gen", "scheduler_reset", "denoise_inner", "update_chunk"]

    @property
    def description(self) -> str:
        return (
            "I2V chunk denoise step that iterates over temporal chunks.\n"
            "At each chunk: history_slice (I2V) -> noise_gen -> scheduler_reset -> denoise_inner -> update_chunk."
        )


class HeliosPyramidDistilledChunkDenoiseInner(ModularPipelineBlocks):
    """Nested pyramid stage loop with DMD denoising for distilled checkpoints.

    Same progressive multi-resolution strategy as HeliosPyramidChunkDenoiseInner, but:
    - Guidance is disabled (guidance_scale=1.0, no unconditional pass)
    - Supports is_amplify_first_chunk (doubles first chunk's timesteps via scheduler)
    - Tracks start_point_list and passes DMD-specific args to scheduler.step()
    """

    model_name = "helios-pyramid"

    @property
    def description(self) -> str:
        return (
            "Distilled pyramid denoising inner block for DMD checkpoints. Loops over pyramid stages "
            "from smallest to full resolution with guidance disabled and DMD scheduler support."
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("transformer", HeliosTransformer3DModel),
            ComponentSpec("scheduler", HeliosScheduler),
            ComponentSpec(
                "guider",
                ClassifierFreeGuidance,
                config=FrozenDict({"guidance_scale": 1.0}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam.template("latents"),
            InputParam("prompt_embeds", type_hint=torch.Tensor),
            InputParam("negative_prompt_embeds", type_hint=torch.Tensor),
            InputParam.template("denoiser_input_fields"),
            InputParam(
                "pyramid_num_inference_steps_list",
                default=[2, 2, 2],
                type_hint=list,
                description="Number of denoising steps per pyramid stage.",
            ),
            InputParam(
                "is_amplify_first_chunk",
                default=True,
                type_hint=bool,
                description="Whether to double the first chunk's timesteps via the scheduler for amplified generation.",
            ),
            InputParam.template("attention_kwargs"),
            InputParam.template("generator"),
        ]

    @torch.no_grad()
    def __call__(self, components: HeliosModularPipeline, block_state: BlockState, k: int):
        """Denoise `block_state.latents` through every pyramid stage using the DMD arguments of the scheduler."""
        device = components._execution_device
        transformer_dtype = components.transformer.dtype
        latents = block_state.latents
        pyramid_num_stages = len(block_state.pyramid_num_inference_steps_list)
        patch_size = components.transformer.config.patch_size

        is_first_chunk = k == 0

        # Track start points for DMD scheduler: entry i_s holds the latents stage i_s started from.
        start_point_list = [latents]

        # Guider inputs: only encoder_hidden_states differs between cond/uncond
        guider_inputs = {
            "encoder_hidden_states": (block_state.prompt_embeds, block_state.negative_prompt_embeds),
        }

        # Shared kwargs from denoiser_input_fields (excludes guider-managed ones) plus history latents
        shared_kwargs = _build_shared_denoiser_kwargs(components, block_state, guider_inputs, transformer_dtype)

        for i_s in range(pyramid_num_stages):
            # --- Stage setup ---
            # a. Compute mu from current resolution (before upsample, matching standard pipeline)
            mu = _compute_stage_mu(components, latents, patch_size)

            # b. Set scheduler timesteps for this stage (with DMD amplification)
            num_inference_steps = block_state.pyramid_num_inference_steps_list[i_s]
            components.scheduler.set_timesteps(
                num_inference_steps,
                i_s,
                device=device,
                mu=mu,
                is_amplify_first_chunk=block_state.is_amplify_first_chunk and is_first_chunk,
            )
            timesteps = components.scheduler.timesteps

            # c. Upsample + block noise correction for stages > 0
            if i_s > 0:
                latents = _upsample_with_block_noise(
                    components, latents, i_s, patch_size, device, block_state.generator, transformer_dtype
                )
                start_point_list.append(latents)

            # --- Timestep denoising loop ---
            num_warmup_steps = len(timesteps) - num_inference_steps * components.scheduler.order

            with tqdm(total=num_inference_steps) as progress_bar:
                for i, t in enumerate(timesteps):
                    noise_pred = _predict_guided_noise(
                        components,
                        guider_inputs,
                        shared_kwargs,
                        latents,
                        t,
                        i,
                        num_inference_steps,
                        transformer_dtype,
                    )

                    # Scheduler step with DMD args
                    latents = components.scheduler.step(
                        noise_pred,
                        t,
                        latents,
                        generator=block_state.generator,
                        return_dict=False,
                        cur_sampling_step=i,
                        dmd_noisy_tensor=start_point_list[i_s],
                        dmd_sigmas=components.scheduler.sigmas,
                        dmd_timesteps=components.scheduler.timesteps,
                        all_timesteps=timesteps,
                    )[0]

                    if _should_update_progress(i, len(timesteps), num_warmup_steps, components.scheduler.order):
                        progress_bar.update()

        block_state.latents = latents
        return components, block_state


class HeliosPyramidChunkDenoiseStep(HeliosChunkLoopWrapper):
    """T2V pyramid chunk denoising: history slice -> pyramid noise gen -> pyramid denoise inner -> update."""

    block_classes = [
        HeliosChunkHistorySliceStep,
        HeliosPyramidChunkNoiseGenStep,
        HeliosPyramidChunkDenoiseInner,
        HeliosChunkUpdateStep,
    ]
    block_names = ["history_slice", "noise_gen", "denoise_inner", "update_chunk"]

    @property
    def description(self) -> str:
        return (
            "T2V pyramid chunk denoise step that iterates over temporal chunks.\n"
            "At each chunk: history_slice -> noise_gen (pyramid) -> denoise_inner (pyramid stages) -> update_chunk.\n"
            "Denoising starts at the smallest resolution and progressively upsamples."
        )


class HeliosPyramidI2VChunkDenoiseStep(HeliosChunkLoopWrapper):
    """I2V pyramid chunk denoising: I2V history slice -> pyramid noise gen -> pyramid denoise inner -> update."""

    block_classes = [
        HeliosI2VChunkHistorySliceStep,
        HeliosPyramidChunkNoiseGenStep,
        HeliosPyramidChunkDenoiseInner,
        HeliosChunkUpdateStep,
    ]
    block_names = ["history_slice", "noise_gen", "denoise_inner", "update_chunk"]

    @property
    def description(self) -> str:
        return (
            "I2V pyramid chunk denoise step that iterates over temporal chunks.\n"
            "At each chunk: history_slice (I2V) -> noise_gen (pyramid) -> denoise_inner (pyramid stages) -> update_chunk.\n"
            "Denoising starts at the smallest resolution and progressively upsamples."
        )


class HeliosPyramidDistilledChunkDenoiseStep(HeliosChunkLoopWrapper):
    """T2V distilled pyramid chunk denoising with DMD scheduler and no CFG."""

    block_classes = [
        HeliosChunkHistorySliceStep,
        HeliosPyramidChunkNoiseGenStep,
        HeliosPyramidDistilledChunkDenoiseInner,
        HeliosChunkUpdateStep,
    ]
    block_names = ["history_slice", "noise_gen", "denoise_inner", "update_chunk"]

    @property
    def description(self) -> str:
        return (
            "T2V distilled pyramid chunk denoise step with DMD scheduler.\n"
            "At each chunk: history_slice -> noise_gen (pyramid) -> denoise_inner (distilled/DMD) -> update_chunk."
        )


class HeliosPyramidDistilledI2VChunkDenoiseStep(HeliosChunkLoopWrapper):
    """I2V distilled pyramid chunk denoising with DMD scheduler and no CFG."""

    block_classes = [
        HeliosI2VChunkHistorySliceStep,
        HeliosPyramidChunkNoiseGenStep,
        HeliosPyramidDistilledChunkDenoiseInner,
        HeliosChunkUpdateStep,
    ]
    block_names = ["history_slice", "noise_gen", "denoise_inner", "update_chunk"]

    @property
    def description(self) -> str:
        return (
            "I2V distilled pyramid chunk denoise step with DMD scheduler.\n"
            "At each chunk: history_slice (I2V) -> noise_gen (pyramid) -> denoise_inner (distilled/DMD) -> update_chunk."
        )