# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see .
"""Before-denoise blocks for WorldEngine modular pipeline."""

from typing import List, Optional, Union

import PIL.Image
import torch
from torch import nn, Tensor
from tensordict import TensorDict
from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE, BlockMask

from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.modular_pipelines import (
    ModularPipelineBlocks,
    ModularPipeline,
    PipelineState,
    SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    ConfigSpec,
    InputParam,
    OutputParam,
)

logger = logging.get_logger(__name__)


def make_block_mask(T: int, L: int, written: torch.Tensor) -> BlockMask:
    """
    Create a block mask for flex_attention.

    Args:
        T: Q length for this frame
        L: KV capacity == written.numel()
        written: [L] bool, True where there is valid KV data

    Returns:
        A ``BlockMask`` where every query block attends to exactly the KV
        slots flagged in ``written``; fully-written KV blocks skip the
        ``mask_mod`` entirely, partially-written blocks apply it per token.
    """
    BS = _DEFAULT_SPARSE_BLOCK_SIZE
    KV_blocks = (L + BS - 1) // BS
    Q_blocks = (T + BS - 1) // BS

    # Pad the per-token flags up to a whole number of blocks: [KV_blocks, BS]
    written_blocks = torch.nn.functional.pad(written, (0, KV_blocks * BS - L)).view(
        KV_blocks, BS
    )

    # Block-level occupancy
    block_any = written_blocks.any(-1)  # block has at least one written token
    block_all = written_blocks.all(-1)  # block is fully written

    # Every Q-block sees the same KV-block pattern
    nonzero_bm = block_any[None, :].expand(Q_blocks, KV_blocks)  # [Q_blocks, KV_blocks]
    full_bm = block_all[None, :].expand_as(nonzero_bm)  # [Q_blocks, KV_blocks]
    partial_bm = nonzero_bm & ~full_bm  # [Q_blocks, KV_blocks]

    def dense_to_ordered(dense_mask: torch.Tensor):
        # dense_mask: [Q_blocks, KV_blocks] bool
        # returns: [1,1,Q_blocks], [1,1,Q_blocks,KV_blocks]
        num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32)  # [Q_blocks]
        # stable argsort keeps active block indices first, in ascending order
        indices = dense_mask.argsort(dim=-1, descending=True, stable=True).to(
            torch.int32
        )
        return num_blocks[None, None].contiguous(), indices[None, None].contiguous()

    # Partial blocks (need mask_mod)
    kv_num_blocks, kv_indices = dense_to_ordered(partial_bm)
    # Full blocks (mask_mod can be skipped entirely)
    full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm)

    def mask_mod(b, h, q, kv):
        # Token-level check used only inside partially-written blocks.
        return written[kv]

    bm = BlockMask.from_kv_blocks(
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        BLOCK_SIZE=BS,
        mask_mod=mask_mod,
        seq_lengths=(T, L),
        compute_q_blocks=False,  # no backward, avoids the transpose/_ordered_to_dense path
    )
    return bm


class LayerKVCache(nn.Module):
    """
    Ring-buffer KV cache with fixed capacity L (tokens) for history plus one
    extra frame (tokens_per_frame) at the tail holding the current frame.
    """

    def __init__(
        self, B, H, L, Dh, dtype, tokens_per_frame: int, pinned_dilation: int = 1
    ):
        super().__init__()
        self.tpf = tokens_per_frame
        self.L = L
        # total KV capacity: ring (L) + tail frame (tpf)
        self.capacity = L + self.tpf
        self.pinned_dilation = pinned_dilation
        # Each bucket owns one frame's worth of slots; dilation means only
        # every `pinned_dilation`-th frame is persisted into the ring.
        self.num_buckets = (L // self.tpf) // self.pinned_dilation
        assert (L // self.tpf) % pinned_dilation == 0 and L % self.tpf == 0
        # KV buffer: [2, B, H, capacity, Dh]
        self.kv = nn.Buffer(
            torch.zeros(2, B, H, self.capacity, Dh, dtype=dtype),
            persistent=False,
        )
        # which slots have ever been written
        # tail slice [L, L+tpf) always holds the current frame and is considered written
        written = torch.zeros(self.capacity, dtype=torch.bool)
        written[L:] = True
        self.written = nn.Buffer(written, persistent=False)
        # Precompute indices:
        #   frame_offsets: [0, 1, ..., tpf-1]     (for ring indexing)
        #   current_idx:   [L, L+1, ..., L+tpf-1] (tail slice)
        self.frame_offsets = nn.Buffer(
            torch.arange(self.tpf, dtype=torch.long), persistent=False
        )
        self.current_idx = nn.Buffer(self.frame_offsets + L, persistent=False)

    def reset(self):
        """Clear all cached KV data; the tail slice stays flagged as written."""
        self.kv.zero_()
        self.written.zero_()
        self.written[self.L :].fill_(True)

    def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool):
        """
        Insert one frame of KV data and build the matching attention mask.

        Args:
            kv: [2, B, H, T, Dh] for a single frame (T = tokens_per_frame)
            pos_ids: TensorDict with t_pos [B, T], all equal per frame (ignoring -1)
            is_frozen: when True, the frame is visible only via the tail slice
                and is NOT persisted into the history ring.

        Returns:
            (k, v, block_mask): full K and V buffers [B, H, capacity, Dh] and
            the BlockMask restricting attention to valid slots.
        """
        T = self.tpf
        t_pos = pos_ids["t_pos"]
        if not torch.compiler.is_compiling():
            torch._check(
                kv.size(3) == self.tpf, "KV cache expects exactly one frame per upsert"
            )
            torch._check(t_pos.shape == (kv.size(1), T), "t_pos must be [B, T]")
            torch._check(self.tpf <= self.L, "frame longer than KV ring capacity")
            torch._check(
                self.L % self.tpf == 0,
                f"L ({self.L}) must be a multiple of tokens_per_frame ({self.tpf})",
            )
            torch._check(
                self.kv.size(3) == self.capacity,
                "KV buffer has unexpected length (expected L + tokens_per_frame)",
            )
            torch._check(
                (t_pos >= 0).all().item(),
                "t_pos must be non-negative during inference",
            )
            torch._check(
                ((t_pos == t_pos[:, :1]).all()).item(),
                "t_pos must be constant within frame",
            )
        frame_t = t_pos[0, 0]

        # map frame_t to a bucket, each bucket owns T contiguous slots
        bucket = (frame_t + (self.pinned_dilation - 1)) // self.pinned_dilation
        slot = bucket % self.num_buckets
        base = slot * T
        # indices in the ring for this frame: [T] in [0, L)
        ring_idx = self.frame_offsets + base

        # Always write current frame into the tail slice [L, L+T):
        # this is the "self-attention component" for the current frame.
        self.kv.index_copy_(3, self.current_idx, kv)

        # Only frames aligned with the dilation stride get a ring slot.
        write_step = frame_t.remainder(self.pinned_dilation) == 0
        # Hide this frame's own (stale) ring slot from the mask so the frame
        # never attends to an older version of itself.
        mask_written = self.written.clone()
        mask_written[ring_idx] = mask_written[ring_idx] & ~write_step
        bm = make_block_mask(T, self.capacity, mask_written)

        # Persist current frame into the ring for future queries when unfrozen.
        if not is_frozen:
            # Non-write steps fall back to rewriting the tail slice (a no-op
            # location-wise) so the index_copy_ shape stays static for compile.
            dst = torch.where(write_step, ring_idx, self.current_idx)
            self.kv.index_copy_(3, dst, kv)
            self.written[dst] = True

        k, v = self.kv.unbind(0)
        return k, v, bm


class StaticKVCache(nn.Module):
    """Static KV cache with per-layer configuration for local/global attention."""

    def __init__(self, config, batch_size, dtype):
        super().__init__()
        self.tpf = config.tokens_per_frame
        local_L = config.local_window * self.tpf
        global_L = config.global_window * self.tpf
        period = config.global_attn_period
        off = getattr(config, "global_attn_offset", 0) % period
        # Every `period`-th layer (offset by `off`) gets the larger global
        # window and its pinned dilation; the rest use the local window.
        self.layers = nn.ModuleList(
            [
                LayerKVCache(
                    batch_size,
                    getattr(config, "n_kv_heads", config.n_heads),
                    global_L if ((layer_idx - off) % period == 0) else local_L,
                    config.d_model // config.n_heads,
                    dtype,
                    self.tpf,
                    (
                        config.global_pinned_dilation
                        if ((layer_idx - off) % period == 0)
                        else 1
                    ),
                )
                for layer_idx in range(config.n_layers)
            ]
        )
        self._is_frozen = True

    def reset(self):
        """Clear every layer's cache and re-freeze."""
        for layer in self.layers:
            layer.reset()
        self._is_frozen = True

    def set_frozen(self, is_frozen: bool):
        """Toggle whether upserted frames are persisted into the history ring."""
        self._is_frozen = is_frozen

    def upsert(self, k: Tensor, v: Tensor, pos_ids: TensorDict, layer: int):
        """Insert one frame of K/V for `layer`; returns (k, v, block_mask)."""
        kv = torch.stack([k, v], dim=0)
        return self.layers[layer].upsert(kv, pos_ids, self._is_frozen)


class WorldEngineSetTimestepsStep(ModularPipelineBlocks):
    """Sets up the scheduler sigmas for rectified flow denoising."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Sets up scheduler sigmas for rectified flow denoising"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [ConfigSpec("scheduler_sigmas", [1.0, 0.94921875, 0.83984375, 0.0])]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "scheduler_sigmas",
                type_hint=List[float],
                description="Custom scheduler sigmas (overrides config)",
            ),
            InputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "scheduler_sigmas",
                type_hint=torch.Tensor,
                description="Tensor of scheduler sigmas for denoising",
            ),
            OutputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        # Use provided sigmas or get from config
        sigmas = block_state.scheduler_sigmas
        if sigmas is None:
            sigmas = components.config.scheduler_sigmas
        block_state.scheduler_sigmas = torch.tensor(
            sigmas, device=device, dtype=dtype
        )

        # Normalize the frame timestamp to a [1, 1] long tensor.
        frame_ts = block_state.frame_timestamp
        if frame_ts is None:
            frame_ts = torch.tensor([[0]], dtype=torch.long, device=device)
        elif isinstance(frame_ts, int):
            frame_ts = torch.tensor([[frame_ts]], dtype=torch.long, device=device)
        block_state.frame_timestamp = frame_ts

        self.set_block_state(state, block_state)
        return components, state


class WorldEngineSetupKVCacheStep(ModularPipelineBlocks):
    """Initializes or reuses the KV cache for autoregressive generation."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Initializes or reuses KV cache for autoregressive frame generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "kv_cache",
                type_hint=Optional[StaticKVCache],
                description="Existing KV cache (will be reused if provided)",
            ),
            InputParam(
                "reset_cache",
                type_hint=bool,
                default=False,
                description="If True, reset the KV cache even if one exists",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "kv_cache",
                type_hint=StaticKVCache,
                description="KV cache for transformer attention",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        # Create or reuse KV cache
        if block_state.kv_cache is None:
            block_state.kv_cache = StaticKVCache(
                components.transformer.config,
                batch_size=1,
                dtype=dtype,
            ).to(device)
        elif block_state.reset_cache:
            block_state.kv_cache.reset()

        self.set_block_state(state, block_state)
        return components, state


class WorldEnginePrepareLatentsStep(ModularPipelineBlocks):
    """Prepares latents for frame generation, optionally encoding an input image."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return (
            "Prepares latents for frame generation. If an image is provided on the "
            "first frame, encodes it and caches it as context. Always creates fresh "
            "random noise for the actual denoising."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict(
                    {
                        "vae_scale_factor": 16,
                        "do_normalize": False,
                        "do_convert_rgb": False,
                    }
                ),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [
            ConfigSpec("channels", 16),
            ConfigSpec("height", 16),
            ConfigSpec("width", 16),
            ConfigSpec("patch", [2, 2]),
            ConfigSpec("vae_scale_factor", 16),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Union[PIL.Image.Image, torch.Tensor],
                description="Input image (PIL Image or [H, W, 3] uint8 tensor), only used on first frame",
            ),
            InputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]. Only used if use_random_latents=False.",
            ),
            InputParam(
                "use_random_latents",
                type_hint=bool,
                default=True,
                description="If True, always generate fresh random latents. If False, use provided latents.",
            ),
            InputParam(
                "kv_cache",
                description="KV cache to update",
            ),
            InputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="Prompt embeddings for cache pass",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Prompt padding mask",
            ),
            InputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                description="Button tensor for cache pass",
            ),
            InputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                description="Mouse tensor for cache pass",
            ),
            InputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                description="Scroll tensor for cache pass",
            ),
            InputParam(
                "generator",
                type_hint=torch.Generator,
                default=None,
                description="torch Generator for deterministic output",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]",
            ),
        ]

    @staticmethod
    def _cache_pass(
        transformer,
        x,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Cache pass to persist frame in KV cache."""
        # Unfreeze so the transformer's upserts land in the history ring;
        # sigma=0 marks the frame as fully denoised context.
        kv_cache.set_frozen(False)
        transformer(
            x=x,
            sigma=x.new_zeros((x.size(0), x.size(1))),
            frame_timestamp=frame_timestamp,
            prompt_emb=prompt_emb,
            prompt_pad_mask=prompt_pad_mask,
            mouse=mouse,
            button=button,
            scroll=scroll,
            kv_cache=kv_cache,
        )

    @torch.inference_mode()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device

        # Get latent shape info
        channels = components.config.channels
        height = components.config.height
        width = components.config.width
        patch = components.config.patch
        pH, pW = patch if isinstance(patch, (list, tuple)) else (patch, patch)
        shape = (
            1,
            1,
            channels,
            components.config.vae_scale_factor * pH,
            components.config.vae_scale_factor * pW,
        )

        if block_state.image is not None:
            image = block_state.image
            # Preprocess: PIL/tensor -> [B, C, H, W] float32 in [0, 1]
            image = components.image_processor.preprocess(
                image,
                height=height,
                width=width,
            )
            # Convert to [H, W, 3] uint8 for VAE encoder
            image = (image[0].permute(1, 2, 0) * 255).to(torch.uint8)
            latents = components.vae.encode(image)
            latents = latents.unsqueeze(1)

            # Run cache pass to persist encoded frame
            self._cache_pass(
                components.transformer,
                latents,
                block_state.frame_timestamp,
                block_state.prompt_embeds,
                block_state.prompt_pad_mask,
                block_state.mouse_tensor,
                block_state.button_tensor,
                block_state.scroll_tensor,
                block_state.kv_cache,
            )
            # Advance the timestamp in place so the denoise pass targets the
            # frame after the cached context frame.
            block_state.frame_timestamp.add_(1)

        # Generate latents based on use_random_latents flag.
        # randn_tensor honors the user-supplied generator (deterministic
        # output); with generator=None it matches plain torch.randn.
        if block_state.use_random_latents or block_state.latents is None:
            block_state.latents = randn_tensor(
                shape,
                generator=block_state.generator,
                device=device,
                dtype=torch.bfloat16,
            )

        self.set_block_state(state, block_state)
        return components, state


class WorldEngineBeforeDenoiseStep(SequentialPipelineBlocks):
    """Sequential pipeline that prepares all inputs for denoising."""

    block_classes = [
        WorldEngineSetTimestepsStep,
        WorldEngineSetupKVCacheStep,
        WorldEnginePrepareLatentsStep,
    ]
    block_names = ["set_timesteps", "setup_kv_cache", "prepare_latents"]

    @property
    def description(self) -> str:
        return (
            "Before denoise step that prepares inputs for denoising:\n"
            " - WorldEngineSetTimestepsStep: Set up scheduler sigmas\n"
            " - WorldEngineSetupKVCacheStep: Initialize or reuse KV cache\n"
            " - WorldEnginePrepareLatentsStep: Encode image (if first frame) and create noise"
        )