# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Denoising block for WorldEngine modular pipeline."""

from typing import List, Tuple

import torch
from diffusers.utils import logging
from diffusers.modular_pipelines import (
    ModularPipelineBlocks,
    ModularPipeline,
    PipelineState,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)
from diffusers import AutoModel

logger = logging.get_logger(__name__)


class WorldEngineDenoiseLoop(ModularPipelineBlocks):
    """Denoises latents using rectified flow and updates KV cache.

    Each call runs two phases against the same ``transformer`` component:

    1. a denoising pass with the KV cache frozen, integrating
       ``x = x + dsigma * v`` over the scheduler sigmas, and
    2. a cache pass with the KV cache unfrozen, feeding the denoised frame
       back through the transformer at ``sigma = 0``.

    Afterwards the ``frame_timestamp`` input tensor is incremented in place.
    """

    model_name = "world_engine"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components this block reads from the pipeline (the transformer)."""
        return [ComponentSpec("transformer", AutoModel)]

    @property
    def description(self) -> str:
        """Human-readable summary of what the block does."""
        return (
            "Denoises latents using rectified flow (x = x + dsigma * v) "
            "and updates KV cache for autoregressive generation."
        )

    @property
    def inputs(self) -> List[InputParam]:
        """State entries consumed by ``__call__``.

        ``prompt_pad_mask`` is the only input not marked ``required``.
        """
        return [
            InputParam(
                "scheduler_sigmas",
                required=True,
                type_hint=torch.Tensor,
                description="Scheduler sigmas for denoising",
            ),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="Initial noisy latents [1, 1, C, H, W]",
            ),
            InputParam(
                "kv_cache",
                required=True,
                description="KV cache for transformer attention",
            ),
            InputParam(
                "frame_timestamp",
                required=True,
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Text embeddings for conditioning",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Padding mask for prompt embeddings",
            ),
            InputParam(
                "button_tensor",
                required=True,
                type_hint=torch.Tensor,
                description="One-hot encoded button tensor",
            ),
            InputParam(
                "mouse_tensor",
                required=True,
                type_hint=torch.Tensor,
                description="Mouse velocity tensor",
            ),
            InputParam(
                "scroll_tensor",
                required=True,
                type_hint=torch.Tensor,
                description="Scroll wheel sign tensor",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """State entries written by ``__call__`` (the denoised ``latents``)."""
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Denoised latents",
            ),
        ]

    @staticmethod
    def _denoise_pass(
        transformer,
        x,
        sigmas,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Denoising loop using rectified flow.

        Freezes ``kv_cache`` and then, for every consecutive pair in
        ``sigmas``, asks ``transformer`` for a velocity ``v`` at the current
        sigma and applies ``x = x + (sigma_next - sigma) * v``.

        Args:
            transformer: Callable invoked with the keyword arguments shown
                below; its return value is used as the velocity.
            x: Latents to denoise. Only ``x.size(0)`` and ``x.size(1)`` are
                read here, to shape the per-call sigma tensor.
            sigmas: 1-D tensor of noise levels. ``len(sigmas) - 1`` steps are
                taken; with fewer than two entries ``x`` is returned as is.
            frame_timestamp, prompt_emb, prompt_pad_mask, mouse, button,
                scroll: Conditioning inputs forwarded unchanged.
            kv_cache: Cache object; ``set_frozen(True)`` is called on it and
                it is left frozen on return.

        Returns:
            The latents after the last integration step.
        """
        kv_cache.set_frozen(True)

        # One buffer is allocated up front and refilled in place every step
        # instead of building a fresh sigma tensor per iteration.
        sigma = x.new_empty((x.size(0), x.size(1)))

        # zip() stops at the shorter iterable, so the final sigma is only
        # ever used as the end point of the last step, never as model input.
        for step_sig, step_dsig in zip(sigmas, sigmas.diff()):
            v = transformer(
                x=x,
                sigma=sigma.fill_(step_sig),
                frame_timestamp=frame_timestamp,
                prompt_emb=prompt_emb,
                prompt_pad_mask=prompt_pad_mask,
                mouse=mouse,
                button=button,
                scroll=scroll,
                kv_cache=kv_cache,
            )
            x = x + step_dsig * v
        return x

    @staticmethod
    def _cache_pass(
        transformer,
        x,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Cache pass to persist frame for next generation.

        Unfreezes ``kv_cache`` and runs ``transformer`` once on ``x`` with an
        all-zero sigma. The transformer output is discarded; the call is made
        only for its effect on ``kv_cache``, which is left unfrozen.

        Args:
            transformer: Callable invoked with the keyword arguments shown
                below.
            x: Denoised latents for the current frame.
            frame_timestamp, prompt_emb, prompt_pad_mask, mouse, button,
                scroll: Conditioning inputs forwarded unchanged.
            kv_cache: Cache object; ``set_frozen(False)`` is called on it.
        """
        kv_cache.set_frozen(False)
        transformer(
            x=x,
            sigma=x.new_zeros((x.size(0), x.size(1))),
            frame_timestamp=frame_timestamp,
            prompt_emb=prompt_emb,
            prompt_pad_mask=prompt_pad_mask,
            mouse=mouse,
            button=button,
            scroll=scroll,
            kv_cache=kv_cache,
        )

    @torch.inference_mode()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """Denoise the current frame, update the KV cache, advance the clock.

        Args:
            components: Pipeline exposing the ``transformer`` component.
            state: Pipeline state holding the inputs listed in ``inputs``.

        Returns:
            The ``(components, state)`` pair, with ``latents`` in ``state``
            replaced by the denoised latents.
        """
        block_state = self.get_block_state(state)

        # NOTE(review): the clone presumably detaches the result from a
        # buffer the transformer may reuse on its next call — confirm.
        block_state.latents = self._denoise_pass(
            components.transformer,
            block_state.latents,
            block_state.scheduler_sigmas,
            block_state.frame_timestamp,
            block_state.prompt_embeds,
            block_state.prompt_pad_mask,
            block_state.mouse_tensor,
            block_state.button_tensor,
            block_state.scroll_tensor,
            block_state.kv_cache,
        ).clone()

        self._cache_pass(
            components.transformer,
            block_state.latents,
            block_state.frame_timestamp,
            block_state.prompt_embeds,
            block_state.prompt_pad_mask,
            block_state.mouse_tensor,
            block_state.button_tensor,
            block_state.scroll_tensor,
            block_state.kv_cache,
        )

        # In-place increment: the tensor object held in the state is mutated,
        # so any other holder of the same tensor sees the new timestamp.
        block_state.frame_timestamp.add_(1)

        self.set_block_state(state, block_state)
        return components, state