| | """Before-denoise blocks for WorldEngine modular pipeline.""" |
| |
|
| | from typing import List, Optional, Union |
| |
|
| | import PIL.Image |
| | import torch |
| | from torch import nn, Tensor |
| | from tensordict import TensorDict |
| | from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE, BlockMask |
| |
|
| | from diffusers.configuration_utils import FrozenDict |
| | from diffusers.image_processor import VaeImageProcessor |
| | from diffusers.utils import logging |
| | from diffusers.utils.torch_utils import randn_tensor |
| | from diffusers.modular_pipelines import ( |
| | ModularPipelineBlocks, |
| | ModularPipeline, |
| | PipelineState, |
| | SequentialPipelineBlocks, |
| | ) |
| | from diffusers.modular_pipelines.modular_pipeline_utils import ( |
| | ComponentSpec, |
| | ConfigSpec, |
| | InputParam, |
| | OutputParam, |
| | ) |
| |
|
# Module-level logger following the diffusers logging convention.
logger = logging.get_logger(__name__)
| |
|
| |
|
def make_block_mask(T: int, L: int, written: torch.Tensor) -> BlockMask:
    """
    Create a block mask for flex_attention.

    Args:
        T: Q length for this frame
        L: KV capacity == written.numel()
        written: [L] bool, True where there is valid KV data

    Returns:
        A ``BlockMask`` restricting attention to KV slots where ``written`` is True.
    """
    BS = _DEFAULT_SPARSE_BLOCK_SIZE
    # Number of sparse blocks covering the KV and Q lengths (ceil division).
    KV_blocks = (L + BS - 1) // BS
    Q_blocks = (T + BS - 1) // BS

    # Pad `written` to a whole number of blocks and fold to [KV_blocks, BS]
    # so validity can be reduced per block.
    written_blocks = torch.nn.functional.pad(written, (0, KV_blocks * BS - L)).view(
        KV_blocks, BS
    )

    # Per-block summaries: "any" = block holds at least one valid slot,
    # "all" = every slot in the block is valid.
    block_any = written_blocks.any(-1)
    block_all = written_blocks.all(-1)

    # Every Q block sees the same KV validity pattern, so broadcast one row
    # across all Q blocks. A block is "partial" when some but not all of its
    # slots are valid.
    nonzero_bm = block_any[None, :].expand(Q_blocks, KV_blocks)
    full_bm = block_all[None, :].expand_as(nonzero_bm)
    partial_bm = nonzero_bm & ~full_bm

    def dense_to_ordered(dense_mask: torch.Tensor):
        # Convert a dense [Q_blocks, KV_blocks] bool mask into the
        # (num_blocks, indices) pair that BlockMask.from_kv_blocks expects:
        # per-row counts plus column indices sorted so active blocks come first
        # (stable argsort on bools, descending).
        num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32)
        indices = dense_mask.argsort(dim=-1, descending=True, stable=True).to(
            torch.int32
        )
        # Leading [1, 1] dims are the (batch, head) axes of the block mask.
        return num_blocks[None, None].contiguous(), indices[None, None].contiguous()

    # Partially-valid blocks: flex_attention applies mask_mod elementwise here.
    kv_num_blocks, kv_indices = dense_to_ordered(partial_bm)

    # Fully-valid blocks: attended without per-element masking.
    full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm)

    def mask_mod(b, h, q, kv):
        # Elementwise rule for partial blocks: attend only to written KV slots.
        return written[kv]

    bm = BlockMask.from_kv_blocks(
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        BLOCK_SIZE=BS,
        mask_mod=mask_mod,
        seq_lengths=(T, L),
        compute_q_blocks=False,
    )

    return bm
| |
|
| |
|
class LayerKVCache(nn.Module):
    """
    Ring-buffer KV cache with fixed capacity L (tokens) for history plus
    one extra frame (tokens_per_frame) at the tail holding the current frame.
    """

    def __init__(
        self, B, H, L, Dh, dtype, tokens_per_frame: int, pinned_dilation: int = 1
    ):
        """
        Args:
            B: batch size of the cached KV tensors.
            H: number of KV heads.
            L: history capacity in tokens; must be a multiple of
                ``tokens_per_frame`` and of ``tokens_per_frame * pinned_dilation``.
            Dh: per-head dimension.
            dtype: dtype of the KV storage buffer.
            tokens_per_frame: tokens making up one frame.
            pinned_dilation: only every ``pinned_dilation``-th frame is
                persisted into the ring history.
        """
        super().__init__()
        self.tpf = tokens_per_frame
        self.L = L
        # Total buffer length: L history tokens plus one frame's worth at the
        # tail, reserved for the current (in-flight) frame.
        self.capacity = L + self.tpf
        self.pinned_dilation = pinned_dilation
        # Number of frame-sized slots that remain in the ring after dilation.
        self.num_buckets = (L // self.tpf) // self.pinned_dilation
        assert (L // self.tpf) % pinned_dilation == 0 and L % self.tpf == 0

        # KV storage: [2 (k/v), B, H, capacity, Dh]. Non-persistent: cache
        # contents are runtime state and never serialized with the module.
        self.kv = nn.Buffer(
            torch.zeros(2, B, H, self.capacity, Dh, dtype=dtype),
            persistent=False,
        )

        # Per-slot validity flags. The tail (current-frame region) is marked
        # written from the start so the current frame can attend to itself.
        written = torch.zeros(self.capacity, dtype=torch.bool)
        written[L:] = True
        self.written = nn.Buffer(written, persistent=False)

        # Precomputed index helpers: token offsets within one frame, and the
        # absolute indices of the tail (current-frame) region.
        self.frame_offsets = nn.Buffer(
            torch.arange(self.tpf, dtype=torch.long), persistent=False
        )
        self.current_idx = nn.Buffer(self.frame_offsets + L, persistent=False)

    def reset(self):
        """Clear cached KV data and validity flags; the tail region stays valid."""
        self.kv.zero_()
        self.written.zero_()
        self.written[self.L :].fill_(True)

    def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool):
        """
        Insert one frame's KV and return ``(k, v, block_mask)`` for attention.

        Args:
            kv: [2, B, H, T, Dh] for a single frame (T = tokens_per_frame)
            pos_ids: TensorDict with t_pos [B, T], all equal per frame (ignoring -1)
            is_frozen: if True, the frame is placed only in the tail slot and
                is not persisted into the ring history.
        """
        T = self.tpf
        t_pos = pos_ids["t_pos"]

        # Contract checks; skipped under torch.compile where the .item() calls
        # below would force a graph break.
        if not torch.compiler.is_compiling():
            torch._check(
                kv.size(3) == self.tpf, "KV cache expects exactly one frame per upsert"
            )
            torch._check(t_pos.shape == (kv.size(1), T), "t_pos must be [B, T]")
            torch._check(self.tpf <= self.L, "frame longer than KV ring capacity")
            torch._check(
                self.L % self.tpf == 0,
                f"L ({self.L}) must be a multiple of tokens_per_frame ({self.tpf})",
            )
            torch._check(
                self.kv.size(3) == self.capacity,
                "KV buffer has unexpected length (expected L + tokens_per_frame)",
            )
            torch._check(
                (t_pos >= 0).all().item(),
                "t_pos must be non-negative during inference",
            )
            torch._check(
                ((t_pos == t_pos[:, :1]).all()).item(),
                "t_pos must be constant within frame",
            )

        # Frame index; assumes t_pos is identical across batch and tokens
        # (verified above in eager mode).
        frame_t = t_pos[0, 0]

        # Map the frame to a ring slot: frames share a bucket every
        # `pinned_dilation` steps (ceil-divide), wrapped around the ring.
        bucket = (frame_t + (self.pinned_dilation - 1)) // self.pinned_dilation
        slot = bucket % self.num_buckets
        base = slot * T

        # Absolute token indices of this frame's ring slot.
        ring_idx = self.frame_offsets + base

        # Always place the current frame in the tail region so this step's
        # attention can see it, regardless of whether it is persisted below.
        self.kv.index_copy_(3, self.current_idx, kv)

        # A "write step" is a frame that survives dilation. For this step's
        # mask, hide its ring slot (the fresh data lives in the tail; the ring
        # slot may still hold a stale frame from a previous wrap-around).
        write_step = frame_t.remainder(self.pinned_dilation) == 0
        mask_written = self.written.clone()
        mask_written[ring_idx] = mask_written[ring_idx] & ~write_step
        bm = make_block_mask(T, self.capacity, mask_written)

        # Persist into the ring only when the cache is not frozen.
        if not is_frozen:
            # On write steps the frame goes into its ring slot; otherwise the
            # copy harmlessly re-writes the tail region.
            dst = torch.where(write_step, ring_idx, self.current_idx)
            self.kv.index_copy_(3, dst, kv)
            self.written[dst] = True

        k, v = self.kv.unbind(0)
        return k, v, bm
| |
|
| |
|
class StaticKVCache(nn.Module):
    """Static KV cache with per-layer configuration for local/global attention."""

    def __init__(self, config, batch_size, dtype):
        super().__init__()

        self.tpf = config.tokens_per_frame

        # History capacities in tokens for local- vs global-attention layers.
        local_L = config.local_window * self.tpf
        global_L = config.global_window * self.tpf

        # Every `period`-th layer (shifted by the optional offset) is a
        # global-attention layer with a larger, dilated history.
        period = config.global_attn_period
        off = getattr(config, "global_attn_offset", 0) % period
        num_kv_heads = getattr(config, "n_kv_heads", config.n_heads)
        head_dim = config.d_model // config.n_heads

        caches = []
        for layer_idx in range(config.n_layers):
            is_global = (layer_idx - off) % period == 0
            caches.append(
                LayerKVCache(
                    batch_size,
                    num_kv_heads,
                    global_L if is_global else local_L,
                    head_dim,
                    dtype,
                    self.tpf,
                    config.global_pinned_dilation if is_global else 1,
                )
            )
        self.layers = nn.ModuleList(caches)

        self._is_frozen = True

    def reset(self):
        """Reset every per-layer cache and freeze writes again."""
        for layer in self.layers:
            layer.reset()
        self._is_frozen = True

    def set_frozen(self, is_frozen: bool):
        """Toggle whether upserts persist frames into the ring history."""
        self._is_frozen = is_frozen

    def upsert(self, k: Tensor, v: Tensor, pos_ids: TensorDict, layer: int):
        """Stack K/V along a new leading dim and delegate to the layer's cache."""
        kv = torch.stack([k, v], dim=0)
        return self.layers[layer].upsert(kv, pos_ids, self._is_frozen)
| |
|
| |
|
class WorldEngineSetTimestepsStep(ModularPipelineBlocks):
    """Sets up the scheduler sigmas for rectified flow denoising."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Sets up scheduler sigmas for rectified flow denoising"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # This step only manipulates state; it needs no loaded components.
        return []

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [ConfigSpec("scheduler_sigmas", [1.0, 0.94921875, 0.83984375, 0.0])]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "scheduler_sigmas",
                type_hint=List[float],
                description="Custom scheduler sigmas (overrides config)",
            ),
            InputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "scheduler_sigmas",
                type_hint=torch.Tensor,
                description="Tensor of scheduler sigmas for denoising",
            ),
            OutputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        # Prefer user-supplied sigmas, falling back to the pipeline config,
        # and materialize them as a tensor on the execution device.
        raw_sigmas = (
            block_state.scheduler_sigmas
            if block_state.scheduler_sigmas is not None
            else components.config.scheduler_sigmas
        )
        block_state.scheduler_sigmas = torch.tensor(
            raw_sigmas, device=device, dtype=dtype
        )

        # Normalize the frame timestamp to a [1, 1] long tensor
        # (missing -> 0, plain int -> tensor).
        timestamp = block_state.frame_timestamp
        if timestamp is None:
            timestamp = 0
        if isinstance(timestamp, int):
            timestamp = torch.tensor([[timestamp]], dtype=torch.long, device=device)
        block_state.frame_timestamp = timestamp

        self.set_block_state(state, block_state)
        return components, state
| |
|
| |
|
class WorldEngineSetupKVCacheStep(ModularPipelineBlocks):
    """Initializes or reuses the KV cache for autoregressive generation."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Initializes or reuses KV cache for autoregressive frame generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # No loaded components needed; the cache is built from the transformer config.
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "kv_cache",
                type_hint=Optional[StaticKVCache],
                description="Existing KV cache (will be reused if provided)",
            ),
            InputParam(
                "reset_cache",
                type_hint=bool,
                default=False,
                description="If True, reset the KV cache even if one exists",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "kv_cache",
                type_hint=StaticKVCache,
                description="KV cache for transformer attention",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        cache = block_state.kv_cache
        if cache is None:
            # First call: build a fresh cache sized from the transformer config.
            cache = StaticKVCache(
                components.transformer.config,
                batch_size=1,
                dtype=dtype,
            ).to(device)
            block_state.kv_cache = cache
        elif block_state.reset_cache:
            # Reuse the existing cache object but wipe its contents.
            cache.reset()

        self.set_block_state(state, block_state)
        return components, state
| |
|
| |
|
class WorldEnginePrepareLatentsStep(ModularPipelineBlocks):
    """Prepares latents for frame generation, optionally encoding an input image."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return (
            "Prepares latents for frame generation. If an image is provided on the "
            "first frame, encodes it and caches it as context. Always creates fresh "
            "random noise for the actual denoising."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict(
                    {
                        "vae_scale_factor": 16,
                        "do_normalize": False,
                        "do_convert_rgb": False,
                    }
                ),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [
            ConfigSpec("channels", 16),
            ConfigSpec("height", 16),
            ConfigSpec("width", 16),
            ConfigSpec("patch", [2, 2]),
            ConfigSpec("vae_scale_factor", 16),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Union[PIL.Image.Image, torch.Tensor],
                description="Input image (PIL Image or [H, W, 3] uint8 tensor), only used on first frame",
            ),
            InputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]. Only used if use_random_latents=False.",
            ),
            InputParam(
                "use_random_latents",
                type_hint=bool,
                default=True,
                description="If True, always generate fresh random latents. If False, use provided latents.",
            ),
            InputParam(
                "kv_cache",
                description="KV cache to update",
            ),
            InputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="Prompt embeddings for cache pass",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Prompt padding mask",
            ),
            InputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                description="Button tensor for cache pass",
            ),
            InputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                description="Mouse tensor for cache pass",
            ),
            InputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                description="Scroll tensor for cache pass",
            ),
            InputParam(
                "generator",
                type_hint=torch.Generator,
                default=None,
                description="torch Generator for deterministic output",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]",
            ),
        ]

    @staticmethod
    def _cache_pass(
        transformer,
        x,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Cache pass to persist frame in KV cache.

        Runs the transformer with sigma == 0 (clean latents) while the cache is
        unfrozen, so the frame's KV is written into the ring history.
        """
        kv_cache.set_frozen(False)
        transformer(
            x=x,
            sigma=x.new_zeros((x.size(0), x.size(1))),
            frame_timestamp=frame_timestamp,
            prompt_emb=prompt_emb,
            prompt_pad_mask=prompt_pad_mask,
            mouse=mouse,
            button=button,
            scroll=scroll,
            kv_cache=kv_cache,
        )

    @torch.inference_mode()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device

        # Latent geometry from the pipeline config.
        channels = components.config.channels
        height = components.config.height
        width = components.config.width
        patch = components.config.patch

        pH, pW = patch if isinstance(patch, (list, tuple)) else (patch, patch)
        # NOTE(review): latent H/W are derived from vae_scale_factor * patch,
        # not from the configured height/width — confirm this matches the VAE.
        shape = (
            1,
            1,
            channels,
            components.config.vae_scale_factor * pH,
            components.config.vae_scale_factor * pW,
        )

        if block_state.image is not None:
            # First-frame context: encode the image and run a cache pass so the
            # transformer's KV cache holds it as history.
            image = components.image_processor.preprocess(
                block_state.image,
                height=height,
                width=width,
            )
            # Channels-first float -> [H, W, 3] uint8 (assumes preprocess
            # output is in [0, 1] since do_normalize=False).
            image = (image[0].permute(1, 2, 0) * 255).to(torch.uint8)

            assert image.dtype == torch.uint8, (
                f"Expected uint8 image, got {image.dtype}"
            )

            latents = components.vae.encode(image)
            latents = latents.unsqueeze(1)  # add frame dim -> [1, 1, C, H, W]

            # Persist the encoded frame into the KV cache.
            self._cache_pass(
                components.transformer,
                latents,
                block_state.frame_timestamp,
                block_state.prompt_embeds,
                block_state.prompt_pad_mask,
                block_state.mouse_tensor,
                block_state.button_tensor,
                block_state.scroll_tensor,
                block_state.kv_cache,
            )
            # Advance the timestamp in place so downstream steps see the next frame.
            block_state.frame_timestamp.add_(1)

        # Fresh noise for denoising. Fix: use randn_tensor so the declared
        # `generator` input actually makes output deterministic — the previous
        # torch.randn call silently ignored it.
        if block_state.use_random_latents or block_state.latents is None:
            block_state.latents = randn_tensor(
                shape,
                generator=block_state.generator,
                device=device,
                dtype=torch.bfloat16,
            )

        self.set_block_state(state, block_state)
        return components, state
| |
|
| |
|
class WorldEngineBeforeDenoiseStep(SequentialPipelineBlocks):
    """Sequential pipeline that prepares all inputs for denoising."""

    # Sub-blocks execute in order; `block_names` maps 1:1 onto `block_classes`.
    block_classes = [
        WorldEngineSetTimestepsStep,
        WorldEngineSetupKVCacheStep,
        WorldEnginePrepareLatentsStep,
    ]
    block_names = ["set_timesteps", "setup_kv_cache", "prepare_latents"]

    @property
    def description(self) -> str:
        # Human-readable summary surfaced by the modular-pipeline tooling.
        return (
            "Before denoise step that prepares inputs for denoising:\n"
            " - WorldEngineSetTimestepsStep: Set up scheduler sigmas\n"
            " - WorldEngineSetupKVCacheStep: Initialize or reuse KV cache\n"
            " - WorldEnginePrepareLatentsStep: Encode image (if first frame) and create noise"
        )
| |
|