|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Denoising block for WorldEngine modular pipeline.""" |
|
|
|
|
|
from typing import List |
|
|
|
|
|
import torch |
|
|
|
|
|
from diffusers.utils import logging |
|
|
from diffusers.modular_pipelines import ( |
|
|
ModularPipelineBlocks, |
|
|
ModularPipeline, |
|
|
PipelineState, |
|
|
) |
|
|
from diffusers.modular_pipelines.modular_pipeline_utils import ( |
|
|
ComponentSpec, |
|
|
InputParam, |
|
|
OutputParam, |
|
|
) |
|
|
from diffusers import AutoModel |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
class WorldEngineDenoiseLoop(ModularPipelineBlocks):
    """Rectified-flow denoising block: integrates latents along the sigma
    schedule (x <- x + dsigma * v) and then persists the clean frame into the
    KV cache so the next frame can attend to it."""

    model_name = "world_engine"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # The only model component this block drives is the transformer.
        return [ComponentSpec("transformer", AutoModel)]

    @property
    def description(self) -> str:
        return (
            "Denoises latents using rectified flow (x = x + dsigma * v) "
            "and updates KV cache for autoregressive generation."
        )

    @property
    def inputs(self) -> List[InputParam]:
        tensor = torch.Tensor
        return [
            InputParam("scheduler_sigmas", required=True, type_hint=tensor, description="Scheduler sigmas for denoising"),
            InputParam("latents", required=True, type_hint=tensor, description="Initial noisy latents [1, 1, C, H, W]"),
            # The cache is an opaque project object, so it carries no type hint.
            InputParam("kv_cache", required=True, description="KV cache for transformer attention"),
            InputParam("frame_timestamp", required=True, type_hint=tensor, description="Current frame timestamp"),
            InputParam("prompt_embeds", required=True, type_hint=tensor, description="Text embeddings for conditioning"),
            # Optional: absent pad mask means no padding to mask out.
            InputParam("prompt_pad_mask", type_hint=tensor, description="Padding mask for prompt embeddings"),
            InputParam("button_tensor", required=True, type_hint=tensor, description="One-hot encoded button tensor"),
            InputParam("mouse_tensor", required=True, type_hint=tensor, description="Mouse velocity tensor"),
            InputParam("scroll_tensor", required=True, type_hint=tensor, description="Scroll wheel sign tensor"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam("latents", type_hint=torch.Tensor, description="Denoised latents"),
        ]

    @staticmethod
    def _denoise_pass(
        transformer,
        x,
        sigmas,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Integrate x along the sigma schedule with Euler steps of the
        rectified-flow ODE; the cache is frozen so these noisy passes do
        not pollute it."""
        kv_cache.set_frozen(True)
        # Reusable per-(batch, frame) sigma buffer, refilled each step.
        sigma_buf = x.new_empty((x.size(0), x.size(1)))
        step_sizes = sigmas.diff()
        for step in range(step_sizes.size(0)):
            velocity = transformer(
                x=x,
                sigma=sigma_buf.fill_(sigmas[step]),
                frame_timestamp=frame_timestamp,
                prompt_emb=prompt_emb,
                prompt_pad_mask=prompt_pad_mask,
                mouse=mouse,
                button=button,
                scroll=scroll,
                kv_cache=kv_cache,
            )
            # Euler update: x <- x + dsigma * v.
            x = x + step_sizes[step] * velocity
        return x

    @staticmethod
    def _cache_pass(
        transformer,
        x,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Run one clean (sigma == 0) forward pass with the cache unfrozen
        so the finished frame is written into it for the next generation."""
        kv_cache.set_frozen(False)
        zero_sigma = x.new_zeros((x.size(0), x.size(1)))
        transformer(
            x=x,
            sigma=zero_sigma,
            frame_timestamp=frame_timestamp,
            prompt_emb=prompt_emb,
            prompt_pad_mask=prompt_pad_mask,
            mouse=mouse,
            button=button,
            scroll=scroll,
            kv_cache=kv_cache,
        )

    @torch.inference_mode()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        block_state = self.get_block_state(state)
        transformer = components.transformer

        # Denoise the current frame; clone so later in-place work cannot
        # alias the loop's intermediate tensor.
        denoised = self._denoise_pass(
            transformer,
            block_state.latents,
            block_state.scheduler_sigmas,
            block_state.frame_timestamp,
            block_state.prompt_embeds,
            block_state.prompt_pad_mask,
            block_state.mouse_tensor,
            block_state.button_tensor,
            block_state.scroll_tensor,
            block_state.kv_cache,
        ).clone()
        block_state.latents = denoised

        # Persist the clean frame into the KV cache.
        self._cache_pass(
            transformer,
            denoised,
            block_state.frame_timestamp,
            block_state.prompt_embeds,
            block_state.prompt_pad_mask,
            block_state.mouse_tensor,
            block_state.button_tensor,
            block_state.scroll_tensor,
            block_state.kv_cache,
        )
        # Advance the timestamp in place for the next autoregressive frame.
        block_state.frame_timestamp.add_(1)

        self.set_block_state(state, block_state)
        return components, state
|
|
|