| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """WorldEngine modular pipeline blocks. |
| |
| All pipeline step classes for text encoding, controller encoding, |
| KV cache setup, latent preparation, denoising, and decoding. |
| """ |
|
|
| import html |
|
|
| import numpy as np |
| import PIL.Image |
| import regex as re |
| import torch |
| from torch import nn, Tensor |
| from tensordict import TensorDict |
| from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE, BlockMask |
| from transformers import AutoTokenizer, UMT5EncoderModel |
|
|
| from diffusers import AutoModel |
| from diffusers.configuration_utils import FrozenDict |
| from diffusers.image_processor import VaeImageProcessor |
| from diffusers.utils import is_ftfy_available, logging |
| from diffusers.modular_pipelines import ( |
| ModularPipelineBlocks, |
| ModularPipeline, |
| PipelineState, |
| SequentialPipelineBlocks, |
| ) |
| from diffusers.modular_pipelines.modular_pipeline_utils import ( |
| ComponentSpec, |
| ConfigSpec, |
| InputParam, |
| InsertableDict, |
| OutputParam, |
| ) |
|
|
| if is_ftfy_available(): |
| import ftfy |
|
|
|
|
| logger = logging.get_logger(__name__) |
|
|
|
|
| |
| |
| |
|
|
def basic_clean(text):
    """Repair broken unicode (when ftfy is installed) and decode HTML entities.

    Args:
        text: Raw prompt string.

    Returns:
        The cleaned string with leading/trailing whitespace removed.
    """
    # ``ftfy`` is only imported at module level when it is installed, so look
    # it up defensively instead of raising NameError on every prompt.
    fixer = globals().get("ftfy")
    if fixer is not None:
        text = fixer.fix_text(text)
    # Unescape twice so doubly-encoded entities ("&amp;amp;") collapse fully.
    return html.unescape(html.unescape(text)).strip()
|
|
|
|
def whitespace_clean(text):
    """Collapse every run of whitespace into one space and trim both ends."""
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()
|
|
|
|
def prompt_clean(text):
    """Normalize a prompt: ``basic_clean`` first, then ``whitespace_clean``."""
    cleaned = basic_clean(text)
    return whitespace_clean(cleaned)
|
|
|
|
| |
| |
| |
|
|
| def make_block_mask(T: int, L: int, written: torch.Tensor) -> BlockMask: |
| """ |
| Create a block mask for flex_attention. |
| |
| T and L must be exact multiples of the sparse block size; written must be |
| block-aligned (each block is either all True or all False). |
| |
| Args: |
| T: Q length for this frame |
| L: KV capacity == written.numel() |
| written: [L] bool, True where there is valid KV data |
| """ |
| BS = _DEFAULT_SPARSE_BLOCK_SIZE |
|
|
| if not torch.compiler.is_compiling(): |
| torch._check(T % BS == 0, f"T ({T}) must be a multiple of block size ({BS})") |
| torch._check(L % BS == 0, f"L ({L}) must be a multiple of block size ({BS})") |
|
|
| Q_blocks = T // BS |
| KV_blocks = L // BS |
|
|
| written_blocks = written.view(KV_blocks, BS) |
| block_any = written_blocks.any(-1) |
|
|
| if not torch.compiler.is_compiling(): |
| assert torch.equal(block_any, written_blocks.all(-1)), "written must be block-aligned" |
|
|
| |
| full_bm = block_any[None, :].expand(Q_blocks, KV_blocks) |
| full_kv_num_blocks = full_bm.sum(dim=-1, dtype=torch.int32)[None, None].contiguous() |
| full_kv_indices = full_bm.argsort(dim=-1, descending=True, stable=True).to(torch.int32)[None, None].contiguous() |
|
|
| |
| kv_num_blocks = torch.zeros((1, 1, Q_blocks), dtype=torch.int32, device=written.device) |
| kv_indices = torch.zeros((1, 1, Q_blocks, KV_blocks), dtype=torch.int32, device=written.device) |
|
|
| return BlockMask.from_kv_blocks( |
| kv_num_blocks, |
| kv_indices, |
| full_kv_num_blocks, |
| full_kv_indices, |
| BLOCK_SIZE=BS, |
| mask_mod=None, |
| seq_lengths=(T, L), |
| compute_q_blocks=False, |
| ) |
|
|
|
|
| |
| |
| |
|
|
class LayerKVCache(nn.Module):
    """
    Fixed-capacity ring-buffer KV cache for a single attention layer.

    The first L token slots hold history frames; one additional frame
    (tokens_per_frame slots) at the tail always holds the current frame.
    """

    def __init__(
        self, B, H, L, Dh, dtype, tokens_per_frame: int, pinned_dilation: int = 1
    ):
        super().__init__()
        frames_in_ring = L // tokens_per_frame

        self.tpf = tokens_per_frame
        self.L = L
        # History ring plus the tail slot for the current frame.
        self.capacity = L + tokens_per_frame
        self.pinned_dilation = pinned_dilation
        self.num_buckets = frames_in_ring // pinned_dilation
        assert frames_in_ring % pinned_dilation == 0 and L % tokens_per_frame == 0

        # Stacked keys/values: index 0 along dim 0 is K, index 1 is V.
        self.kv = nn.Buffer(
            torch.zeros(2, B, H, self.capacity, Dh, dtype=dtype),
            persistent=False,
        )

        # Per-token validity flags; initially only the tail slot is marked.
        flags = torch.zeros(self.capacity, dtype=torch.bool)
        flags[L:] = True
        self.written = nn.Buffer(flags, persistent=False)

        # Scratch copy of ``written`` reused when building each block mask.
        self._mask_written = nn.Buffer(torch.zeros_like(flags), persistent=False)

        # Token offsets within one frame, and the fixed indices of the tail slot.
        self.frame_offsets = nn.Buffer(
            torch.arange(self.tpf, dtype=torch.long), persistent=False
        )
        self.current_idx = nn.Buffer(self.frame_offsets + L, persistent=False)

    def reset(self):
        """Zero the cache and mark only the current-frame tail as written."""
        self.kv.zero_()
        self.written.fill_(False)
        self.written[self.L :] = True

    def _check_upsert_args(self, kv, f_pos):
        """Eager-mode sanity checks for ``upsert`` inputs and buffer geometry."""
        tokens = self.tpf
        torch._check(
            kv.size(3) == self.tpf, "KV cache expects exactly one frame per upsert"
        )
        torch._check(f_pos.shape == (kv.size(1), tokens), "f_pos must be [B, T]")
        torch._check(self.tpf <= self.L, "frame longer than KV ring capacity")
        torch._check(
            self.L % self.tpf == 0,
            f"L ({self.L}) must be a multiple of tokens_per_frame ({self.tpf})",
        )
        torch._check(
            self.kv.size(3) == self.capacity,
            "KV buffer too long (expected L + tokens_per_frame)",
        )
        torch._check(
            (f_pos >= 0).all().item(),
            "f_pos must be non-negative during inference",
        )
        torch._check(
            ((f_pos == f_pos[:, :1]).all()).item(),
            "f_pos must be constant within frame",
        )

    def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool):
        """
        Stage one frame of K/V and return the cache halves plus a block mask.

        Args:
            kv: [2, B, H, T, Dh] for a single frame (T = tokens_per_frame)
            pos_ids: TensorDict with f_pos [B, T] for cache slot indexing
            is_frozen: when True only the tail slot receives the frame; the
                ring slots and the ``written`` flags are left untouched.

        Returns:
            Tuple ``(k, v, block_mask)`` where k and v come from unbinding the
            cache buffer along dim 0.
        """
        tokens = self.tpf
        f_pos = pos_ids["f_pos"]

        if not torch.compiler.is_compiling():
            self._check_upsert_args(kv, f_pos)

        frame_idx = f_pos[0, 0]
        dilation = self.pinned_dilation

        # Ceil-divide the frame index by the dilation to pick its ring bucket,
        # then wrap around the ring and convert to token indices.
        bucket = (frame_idx + (dilation - 1)) // dilation
        ring_idx = self.frame_offsets + (bucket % self.num_buckets) * tokens

        # The incoming frame is always copied into the tail slot first.
        self.kv.index_copy_(3, self.current_idx, kv)

        is_write_step = frame_idx.remainder(dilation) == 0
        visible = self._mask_written
        visible.copy_(self.written)
        # On a write step, mask out the ring slot that is about to be replaced.
        visible[ring_idx] = visible[ring_idx] & ~is_write_step
        block_mask = make_block_mask(tokens, self.capacity, visible)

        # Persist only when unfrozen: into the ring on write steps, otherwise
        # (again) into the tail slot.
        if not is_frozen:
            target = torch.where(is_write_step, ring_idx, self.current_idx)
            self.kv.index_copy_(3, target, kv)
            self.written[target] = True

        keys, values = self.kv.unbind(0)
        return keys, values, block_mask
|
|
|
|
class StaticKVCache(nn.Module):
    """Per-layer KV caches, sized differently for local and global attention layers."""

    def __init__(self, config, batch_size, dtype):
        super().__init__()

        # Tokens per frame.
        self.tpf = config.height * config.width

        local_L = config.local_window * self.tpf
        global_L = config.global_window * self.tpf

        period = config.global_attn_period
        off = getattr(config, "global_attn_offset", 0) % period

        layer_caches = []
        for layer_idx in range(config.n_layers):
            # Every ``period``-th layer (shifted by ``off``) is a global layer.
            is_global = (layer_idx - off) % period == 0
            layer_caches.append(
                LayerKVCache(
                    batch_size,
                    getattr(config, "n_kv_heads", None) or config.n_heads,
                    global_L if is_global else local_L,
                    config.d_model // config.n_heads,
                    dtype,
                    self.tpf,
                    config.global_pinned_dilation if is_global else 1,
                )
            )
        self.layers = nn.ModuleList(layer_caches)

        # Frozen by default: upserts do not persist until set_frozen(False).
        self._is_frozen = True

    def reset(self):
        """Reset every layer cache and return to the frozen state."""
        for layer_cache in self.layers:
            layer_cache.reset()
        self._is_frozen = True

    @torch.inference_mode()
    def get_state(self):
        """Captures a world state to continue via load_state."""
        snapshots = []
        for layer_cache in self.layers:
            kv_copy = layer_cache.kv.detach().clone()
            written_copy = layer_cache.written.detach().clone()
            snapshots.append((kv_copy, written_copy))
        return {"_is_frozen": self._is_frozen, "layers": snapshots}

    @torch.inference_mode()
    def load_state(self, state):
        """Loads a world state object saved via get_state."""
        self._is_frozen = bool(state.get("_is_frozen", True))
        for layer_cache, saved in zip(self.layers, state["layers"]):
            saved_kv, saved_written = saved
            layer_cache.kv.copy_(saved_kv)
            layer_cache.written.copy_(saved_written)

    def set_frozen(self, is_frozen: bool):
        """Set the frozen flag that ``upsert`` forwards to the layer caches."""
        self._is_frozen = is_frozen

    def upsert(self, k: Tensor, v: Tensor, pos_ids: TensorDict, layer: int):
        """Stack k and v and forward them to the cache of layer index ``layer``."""
        stacked = torch.stack([k, v], dim=0)
        layer_cache = self.layers[layer]
        return layer_cache.upsert(stacked, pos_ids, self._is_frozen)
|
|
|
|
| |
| |
| |
|
|
class WorldEngineTextEncoderStep(ModularPipelineBlocks):
    """Encodes text prompts using UMT5-XL for conditioning."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return (
            "Text Encoder step that generates text embeddings to guide frame generation"
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", UMT5EncoderModel),
            ComponentSpec("tokenizer", AutoTokenizer),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        prompt = InputParam(
            "prompt",
            description="The prompt or prompts to guide the frame generation",
        )
        embeds = InputParam(
            "prompt_embeds",
            type_hint=torch.Tensor,
            description="Pre-computed text embeddings",
        )
        pad_mask = InputParam(
            "prompt_pad_mask",
            type_hint=torch.Tensor,
            description="Padding mask for prompt embeddings",
        )
        return [prompt, embeds, pad_mask]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        # Both outputs are tensors routed to the denoiser input fields.
        shared = {"type_hint": torch.Tensor, "kwargs_type": "denoiser_input_fields"}
        return [
            OutputParam(
                "prompt_embeds",
                description="Text embeddings used to guide frame generation",
                **shared,
            ),
            OutputParam(
                "prompt_pad_mask",
                description="Padding mask for prompt embeddings",
                **shared,
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        """Raise ValueError unless ``prompt`` is None, a str, or a list."""
        prompt = block_state.prompt
        if prompt is None or isinstance(prompt, (str, list)):
            return
        raise ValueError(
            f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
        )

    @staticmethod
    def encode_prompt(
        components,
        prompt: str | list[str],
        device: torch.device,
        max_sequence_length: int = 512,
    ):
        """Tokenize and encode prompts.

        Returns:
            ``(prompt_embeds, prompt_pad_mask)``: embeddings with padded
            positions zeroed, and a bool mask that is True at padded positions.
        """
        encoder = components.text_encoder
        dtype = encoder.dtype

        prompts = [prompt] if isinstance(prompt, str) else prompt
        prompts = [prompt_clean(p) for p in prompts]

        tokens = components.tokenizer(
            prompts,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        input_ids = tokens.input_ids.to(device)
        attention_mask = tokens.attention_mask.to(device)

        hidden = encoder(input_ids, attention_mask).last_hidden_state
        hidden = hidden.to(dtype=dtype)

        # Zero the embeddings at padded positions.
        keep = attention_mask.unsqueeze(-1).type_as(hidden)
        prompt_embeds = hidden * keep

        # True where the token is padding.
        prompt_pad_mask = attention_mask.eq(0)

        return prompt_embeds, prompt_pad_mask

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        device = components._execution_device
        if block_state.prompt_embeds is None:
            # Fall back to a default prompt when none (or an empty one) is given.
            block_state.prompt = block_state.prompt or "An explorable world"
            embeds, pad_mask = self.encode_prompt(
                components, block_state.prompt, device
            )
            block_state.prompt_embeds = embeds
            block_state.prompt_pad_mask = pad_mask
        block_state.prompt_embeds = block_state.prompt_embeds.contiguous()

        # Pre-computed embeddings without a mask: treat nothing as padding.
        if block_state.prompt_pad_mask is None:
            block_state.prompt_pad_mask = torch.zeros(
                block_state.prompt_embeds.shape[:2],
                dtype=torch.bool,
                device=device,
            )

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
| |
| |
| |
|
|
class WorldEngineControllerEncoderStep(ModularPipelineBlocks):
    """Encodes controller inputs (mouse + buttons + scroll) for conditioning."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Controller Encoder step that encodes mouse, button, and scroll inputs for conditioning"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return []

    @property
    def expected_configs(self) -> list[ConfigSpec]:
        return [ConfigSpec("n_buttons", 256)]

    @staticmethod
    def _tensor_fields():
        """(name, description) pairs for the three controller tensors."""
        return (
            ("button_tensor", "One-hot encoded button tensor"),
            ("mouse_tensor", "Mouse velocity tensor"),
            ("scroll_tensor", "Scroll wheel sign tensor"),
        )

    @property
    def inputs(self) -> list[InputParam]:
        raw_inputs = [
            InputParam(
                "button",
                type_hint=set[int],
                default=set(),
                description="Set of pressed button IDs",
            ),
            InputParam(
                "mouse",
                type_hint=tuple[float, float],
                default=(0.0, 0.0),
                description="Mouse velocity (x, y)",
            ),
            InputParam(
                "scroll",
                type_hint=int,
                default=0,
                description="Scroll wheel direction (-1, 0, 1)",
            ),
        ]
        # The tensors are also inputs so existing buffers can be updated in place.
        tensor_inputs = [
            InputParam(
                name,
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description=text,
            )
            for name, text in self._tensor_fields()
        ]
        return raw_inputs + tensor_inputs

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam(
                name,
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description=text,
            )
            for name, text in self._tensor_fields()
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        n_buttons = components.config.n_buttons

        # Buttons: allocate once, then clear and set the pressed IDs in place.
        if block_state.button_tensor is None:
            block_state.button_tensor = torch.zeros(
                (1, 1, n_buttons), device=device, dtype=dtype
            )
        block_state.button_tensor.zero_()
        for btn_id in block_state.button or ():
            # IDs outside [0, n_buttons) are ignored.
            if not (0 <= btn_id < n_buttons):
                continue
            block_state.button_tensor[0, 0, btn_id] = 1.0

        # Mouse: allocate once, then write the (x, y) velocity in place.
        if block_state.mouse_tensor is None:
            block_state.mouse_tensor = torch.zeros(
                (1, 1, 2), device=device, dtype=dtype
            )
        mouse = (0.0, 0.0) if block_state.mouse is None else block_state.mouse
        block_state.mouse_tensor[0, 0, 0] = mouse[0]
        block_state.mouse_tensor[0, 0, 1] = mouse[1]

        # Scroll: allocate once, then store only the sign (-1, 0 or 1).
        if block_state.scroll_tensor is None:
            block_state.scroll_tensor = torch.zeros(
                (1, 1, 1), device=device, dtype=dtype
            )
        scroll = 0 if block_state.scroll is None else block_state.scroll
        scroll_sign = float(scroll > 0) - float(scroll < 0)
        block_state.scroll_tensor[0, 0, 0] = scroll_sign

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
| |
| |
| |
|
|
class WorldEngineSetTimestepsStep(ModularPipelineBlocks):
    """Sets up the scheduler sigmas for rectified flow denoising."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Sets up scheduler sigmas for rectified flow denoising"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return []

    @property
    def expected_configs(self) -> list[ConfigSpec]:
        return [ConfigSpec("scheduler_sigmas", [1.0, 0.94921875, 0.83984375, 0.0])]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam(
                "scheduler_sigmas",
                type_hint=list[float],
                description="Custom scheduler sigmas (overrides config)",
            ),
            InputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam(
                "scheduler_sigmas",
                type_hint=torch.Tensor,
                description="Tensor of scheduler sigmas for denoising",
            ),
            OutputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp (unscaled counter)",
            ),
            OutputParam(
                "ts_mult",
                type_hint=int,
                description="Timestamp multiplier (base_fps // latent_fps)",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Resolve sigmas, the frame timestamp and the timestamp multiplier.

        Writes ``scheduler_sigmas`` (tensor on the execution device, in the
        transformer dtype), ``frame_timestamp`` and ``ts_mult`` to the block
        state and returns ``(components, state)``.
        """
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        # Explicit sigmas win; otherwise fall back to the pipeline config.
        sigmas = block_state.scheduler_sigmas
        if sigmas is None:
            sigmas = components.config.scheduler_sigmas
        # ``scheduler_sigmas`` is also this block's tensor output, so a tensor
        # may come back in as input. ``as_tensor`` accepts lists and tensors
        # alike, whereas ``torch.tensor(tensor)`` emits a copy-construct warning.
        block_state.scheduler_sigmas = torch.as_tensor(
            sigmas, device=device, dtype=dtype
        )

        # Default to frame 0; a plain int becomes a [1, 1] long tensor.
        frame_ts = block_state.frame_timestamp
        if frame_ts is None:
            frame_ts = torch.tensor([[0]], dtype=torch.long, device=device)
        elif isinstance(frame_ts, int):
            frame_ts = torch.tensor([[frame_ts]], dtype=torch.long, device=device)

        # Timestamp multiplier from the transformer config, with defaults for
        # any missing field.
        t_cfg = components.transformer.config
        base_fps = getattr(t_cfg, "base_fps", 60)
        inference_fps = getattr(t_cfg, "inference_fps", base_fps)
        temporal_compression = getattr(t_cfg, "temporal_compression", 1)
        latent_fps = inference_fps / temporal_compression
        # NOTE(review): both operands are truncated to int before the floor
        # division, so a latent_fps below 1 raises ZeroDivisionError here.
        ts_mult = int(base_fps) // int(latent_fps)
        block_state.ts_mult = ts_mult
        block_state.frame_timestamp = frame_ts

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
| |
| |
| |
|
|
class WorldEngineSetupKVCacheStep(ModularPipelineBlocks):
    """Initializes or reuses the KV cache for autoregressive generation."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Initializes or reuses KV cache for autoregressive frame generation"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return []

    @property
    def inputs(self) -> list[InputParam]:
        existing_cache = InputParam(
            "kv_cache",
            type_hint=StaticKVCache | None,
            description="Existing KV cache (will be reused if provided)",
        )
        reset_flag = InputParam(
            "reset_cache",
            type_hint=bool,
            default=False,
            description="If True, reset the KV cache even if one exists",
        )
        return [existing_cache, reset_flag]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam(
                "kv_cache",
                type_hint=StaticKVCache,
                description="KV cache for transformer attention",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        cache = block_state.kv_cache
        if cache is None:
            # No cache yet: build one for batch size 1 on the execution device.
            new_cache = StaticKVCache(
                components.transformer.config,
                batch_size=1,
                dtype=dtype,
            )
            block_state.kv_cache = new_cache.to(device)
        elif block_state.reset_cache:
            # Existing cache is reused; clear it only on request.
            cache.reset()

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
| |
| |
| |
|
|
class WorldEnginePrepareLatentsStep(ModularPipelineBlocks):
    """Prepares latents for frame generation, optionally encoding an input image."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return (
            "Prepares latents for frame generation. If an image is provided on the "
            "first frame, encodes it and caches it as context. Always creates fresh "
            "random noise for the actual denoising."
        )

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict(
                    {
                        "vae_scale_factor": 16,
                        "do_normalize": False,
                        "do_convert_rgb": False,
                    }
                ),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> list[ConfigSpec]:
        return [
            ConfigSpec("channels", 16),
            ConfigSpec("height", 16),
            ConfigSpec("width", 16),
            ConfigSpec("patch", [2, 2]),
            ConfigSpec("vae_scale_factor", 16),
        ]

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=PIL.Image.Image | torch.Tensor,
                description="Input image (PIL Image or [H, W, 3] uint8 tensor), only used on first frame",
            ),
            InputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]. Only used if use_random_latents=False.",
            ),
            InputParam(
                "use_random_latents",
                type_hint=bool,
                default=True,
                description="If True, always generate fresh random latents. If False, use provided latents.",
            ),
            InputParam(
                "kv_cache",
                description="KV cache to update",
            ),
            InputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="Prompt embeddings for cache pass",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Prompt padding mask",
            ),
            InputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                description="Button tensor for cache pass",
            ),
            InputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                description="Mouse tensor for cache pass",
            ),
            InputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                description="Scroll tensor for cache pass",
            ),
            InputParam(
                "generator",
                type_hint=torch.Generator,
                default=None,
                description="torch Generator for deterministic output",
            ),
            InputParam(
                "ts_mult",
                required=True,
                type_hint=int,
                description="Timestamp multiplier (base_fps // latent_fps)",
            ),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]",
            ),
        ]

    @staticmethod
    def _cache_pass(
        transformer,
        x,
        frame_timestamp,
        frame_idx,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Cache pass to persist frame in KV cache."""
        # Unfreeze so the transformer's upserts are written into the cache,
        # then run one forward pass at sigma == 0; the output is discarded.
        kv_cache.set_frozen(False)
        transformer(
            x=x,
            sigma=x.new_zeros((x.size(0), x.size(1))),
            frame_timestamp=frame_timestamp,
            frame_idx=frame_idx,
            prompt_emb=prompt_emb,
            prompt_pad_mask=prompt_pad_mask,
            mouse=mouse,
            button=button,
            scroll=scroll,
            kv_cache=kv_cache,
        )

    @torch.inference_mode()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Optionally cache an encoded input image, then provide noise latents.

        When ``image`` is set it is preprocessed, encoded by the VAE, pushed
        through a cache pass and ``frame_timestamp`` is advanced by one in
        place. Fresh noise is then written to ``latents`` unless
        ``use_random_latents`` is False and latents were supplied. The noise
        is drawn from ``generator`` when one is given.
        """
        block_state = self.get_block_state(state)
        device = components._execution_device

        channels = components.config.channels
        height = components.config.height
        width = components.config.width
        patch = components.config.patch
        vae_scale_factor = components.config.vae_scale_factor

        # ``patch`` may be a single int or an (H, W) pair.
        pH, pW = patch if isinstance(patch, (list, tuple)) else (patch, patch)
        latent_H = height * pH
        latent_W = width * pW
        shape = (1, 1, channels, latent_H, latent_W)

        # Pixel resolution the image is resized to before encoding.
        pixel_H = latent_H * vae_scale_factor
        pixel_W = latent_W * vae_scale_factor

        if block_state.image is not None:
            image = components.image_processor.preprocess(
                block_state.image,
                height=pixel_H,
                width=pixel_W,
            )
            # First batch item, channels-last, as uint8. Round and clamp
            # rather than truncate so float error cannot shift a pixel down by
            # one level (assumes the processor yields values in [0, 1]).
            image = (
                (image[0].permute(1, 2, 0) * 255)
                .round()
                .clamp(0, 255)
                .to(torch.uint8)
            )

            # Repeat the still image along a leading axis when the VAE
            # reports a temporal downscale factor above 1.
            t_down = getattr(components.vae, "t_downscale", 1)
            if t_down > 1:
                image = image.unsqueeze(0).expand(t_down, -1, -1, -1)

            latents = components.vae.encode(image)
            latents = latents.unsqueeze(1)

            # Persist the encoded frame as context, then advance the counter.
            ts_mult = block_state.ts_mult
            self._cache_pass(
                components.transformer,
                latents,
                block_state.frame_timestamp * ts_mult,
                block_state.frame_timestamp,
                block_state.prompt_embeds,
                block_state.prompt_pad_mask,
                block_state.mouse_tensor,
                block_state.button_tensor,
                block_state.scroll_tensor,
                block_state.kv_cache,
            )
            block_state.frame_timestamp.add_(1)

        if block_state.use_random_latents or block_state.latents is None:
            generator = block_state.generator
            # Sample on the generator's own device (torch.randn rejects a
            # generator that lives on a different device), then move.
            noise_device = device if generator is None else generator.device
            # NOTE(review): the noise dtype is hard-coded to bfloat16 rather
            # than following the transformer dtype; confirm this is intended.
            block_state.latents = torch.randn(
                shape,
                generator=generator,
                device=noise_device,
                dtype=torch.bfloat16,
            ).to(device)

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
| |
| |
| |
|
|
class WorldEngineBeforeDenoiseStep(SequentialPipelineBlocks):
    """Sequential pipeline that prepares all inputs for denoising."""

    # The two lists are parallel: block_names[i] names block_classes[i].
    block_names = ["set_timesteps", "setup_kv_cache", "prepare_latents"]
    block_classes = [
        WorldEngineSetTimestepsStep,
        WorldEngineSetupKVCacheStep,
        WorldEnginePrepareLatentsStep,
    ]

    @property
    def description(self) -> str:
        summary_lines = [
            "Before denoise step that prepares inputs for denoising:",
            " - WorldEngineSetTimestepsStep: Set up scheduler sigmas",
            " - WorldEngineSetupKVCacheStep: Initialize or reuse KV cache",
            " - WorldEnginePrepareLatentsStep: Encode image (if first frame) and create noise",
        ]
        return "\n".join(summary_lines)
|
|
|
|
| |
| |
| |
|
|
class WorldEngineDenoiseLoop(ModularPipelineBlocks):
    """Denoises latents using rectified flow and updates KV cache."""

    model_name = "world_engine"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        return [ComponentSpec("transformer", AutoModel)]

    @property
    def description(self) -> str:
        return (
            "Denoises latents using rectified flow (x = x + dsigma * v) "
            "and updates KV cache for autoregressive generation."
        )

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam(
                "scheduler_sigmas",
                required=True,
                type_hint=torch.Tensor,
                description="Scheduler sigmas for denoising",
            ),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="Initial noisy latents [1, 1, C, H, W]",
            ),
            InputParam(
                "kv_cache",
                required=True,
                description="KV cache for transformer attention",
            ),
            InputParam(
                "frame_timestamp",
                required=True,
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Text embeddings for conditioning",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Padding mask for prompt embeddings",
            ),
            InputParam(
                "button_tensor",
                required=True,
                type_hint=torch.Tensor,
                description="One-hot encoded button tensor",
            ),
            InputParam(
                "mouse_tensor",
                required=True,
                type_hint=torch.Tensor,
                description="Mouse velocity tensor",
            ),
            InputParam(
                "scroll_tensor",
                required=True,
                type_hint=torch.Tensor,
                description="Scroll wheel sign tensor",
            ),
            InputParam(
                "ts_mult",
                required=True,
                type_hint=int,
                description="Timestamp multiplier (base_fps // latent_fps)",
            ),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam(
                "latents", type_hint=torch.Tensor, description="Denoised latents"
            ),
        ]

    @staticmethod
    def _denoise_pass(
        transformer, x, sigmas, frame_timestamp, frame_idx,
        prompt_emb, prompt_pad_mask, mouse, button, scroll, kv_cache,
    ):
        """Denoising loop using rectified flow."""
        # Frozen: the cache is told not to persist anything during denoising.
        kv_cache.set_frozen(True)

        conditioning = {
            "frame_timestamp": frame_timestamp,
            "frame_idx": frame_idx,
            "prompt_emb": prompt_emb,
            "prompt_pad_mask": prompt_pad_mask,
            "mouse": mouse,
            "button": button,
            "scroll": scroll,
            "kv_cache": kv_cache,
        }

        # One reusable sigma buffer, refilled in place for every step.
        sigma_buf = x.new_empty((x.size(0), x.size(1)))
        deltas = sigmas.diff()
        for level, delta in zip(sigmas, deltas):
            velocity = transformer(x=x, sigma=sigma_buf.fill_(level), **conditioning)
            x = x + delta * velocity
        return x

    @staticmethod
    def _cache_pass(
        transformer, x, frame_timestamp, frame_idx,
        prompt_emb, prompt_pad_mask, mouse, button, scroll, kv_cache,
    ):
        """Cache pass to persist frame for next generation."""
        # Unfrozen forward pass at sigma == 0; the output is discarded.
        kv_cache.set_frozen(False)
        zero_sigma = x.new_zeros((x.size(0), x.size(1)))
        transformer(
            x=x,
            sigma=zero_sigma,
            frame_timestamp=frame_timestamp,
            frame_idx=frame_idx,
            prompt_emb=prompt_emb,
            prompt_pad_mask=prompt_pad_mask,
            mouse=mouse,
            button=button,
            scroll=scroll,
            kv_cache=kv_cache,
        )

    @torch.inference_mode()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        ts_mult = block_state.ts_mult
        transformer = components.transformer
        frame_ts = block_state.frame_timestamp

        # Trailing arguments shared by the denoise pass and the cache pass.
        conditioning = (
            block_state.prompt_embeds,
            block_state.prompt_pad_mask,
            block_state.mouse_tensor,
            block_state.button_tensor,
            block_state.scroll_tensor,
            block_state.kv_cache,
        )

        denoised = self._denoise_pass(
            transformer,
            block_state.latents,
            block_state.scheduler_sigmas,
            frame_ts * ts_mult,
            frame_ts,
            *conditioning,
        )
        block_state.latents = denoised.clone()

        # Persist the finished frame, then advance the frame counter in place.
        self._cache_pass(
            transformer,
            block_state.latents,
            frame_ts * ts_mult,
            frame_ts,
            *conditioning,
        )
        frame_ts.add_(1)

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
| |
| |
| |
|
|
class WorldEngineDecodeStep(ModularPipelineBlocks):
    """Decodes denoised latents back to RGB image using VAE."""

    model_name = "world_engine"

    @property
    def expected_components(self) -> list[ComponentSpec]:
        processor_config = FrozenDict(
            {
                "vae_scale_factor": 16,
                "do_normalize": False,
                "do_convert_rgb": True,
            }
        )
        return [
            ComponentSpec("vae", AutoModel),
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=processor_config,
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Decodes denoised latents to RGB image using the VAE decoder"

    @property
    def inputs(self) -> list[InputParam]:
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="Denoised latent tensor [1, 1, C, H, W]",
            ),
            InputParam(
                "output_type",
                default="pil",
                description="The output format for the generated images (pil, latent, pt, or np)",
            ),
        ]

    @property
    def intermediate_outputs(self) -> list[OutputParam]:
        return [
            OutputParam(
                "images",
                type_hint=PIL.Image.Image | torch.Tensor | np.ndarray,
                description="Decoded RGB image in requested output format",
            ),
        ]

    @staticmethod
    def _format_frames(frames, output_type):
        """Convert decoded frames to the requested container type."""
        if output_type == "pt":
            return frames
        if output_type == "np":
            return frames.cpu().numpy()
        # Any other value is treated as "pil": one PIL image per frame.
        return [PIL.Image.fromarray(frame.cpu().numpy()) for frame in frames]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        latents = block_state.latents
        output_type = block_state.output_type or "pil"

        if output_type == "latent":
            # Hand the latents back untouched.
            block_state.images = latents
        else:
            # Drop dim 1 of the latents before decoding.
            frames = components.vae.decode(latents.squeeze(1))
            # A single 3-D frame gets a leading frame axis.
            if frames.dim() == 3:
                frames = frames.unsqueeze(0)
            block_state.images = self._format_frames(frames, output_type)

        # Clear the latents from the state once consumed.
        block_state.latents = None
        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
| |
| |
| |
|
|
# Ordered registry of block name -> block class; the insertion order here
# determines the order of block_names / block_classes in WorldEngineBlocks.
AUTO_BLOCKS = InsertableDict(
    (
        ("text_encoder", WorldEngineTextEncoderStep),
        ("controller_encoder", WorldEngineControllerEncoderStep),
        ("before_denoise", WorldEngineBeforeDenoiseStep),
        ("denoise", WorldEngineDenoiseLoop),
        ("decode", WorldEngineDecodeStep),
    )
)
|
|
|
|
class WorldEngineBlocks(SequentialPipelineBlocks):
    """Sequential pipeline blocks for WorldEngine frame generation."""

    # Parallel lists taken from AUTO_BLOCKS, in its insertion order.
    block_names = [*AUTO_BLOCKS.keys()]
    block_classes = [*AUTO_BLOCKS.values()]
|
|