|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Before-denoise blocks for WorldEngine modular pipeline.""" |
|
|
|
|
|
from typing import List, Optional, Union |
|
|
|
|
|
import PIL.Image |
|
|
import torch |
|
|
from torch import nn, Tensor |
|
|
from tensordict import TensorDict |
|
|
from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE, BlockMask |
|
|
|
|
|
from diffusers.configuration_utils import FrozenDict |
|
|
from diffusers.image_processor import VaeImageProcessor |
|
|
from diffusers.utils import logging |
|
|
from diffusers.utils.torch_utils import randn_tensor |
|
|
from diffusers.modular_pipelines import ( |
|
|
ModularPipelineBlocks, |
|
|
ModularPipeline, |
|
|
PipelineState, |
|
|
SequentialPipelineBlocks, |
|
|
) |
|
|
from diffusers.modular_pipelines.modular_pipeline_utils import ( |
|
|
ComponentSpec, |
|
|
ConfigSpec, |
|
|
InputParam, |
|
|
OutputParam, |
|
|
) |
|
|
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
|
|
|
|
|
def make_block_mask(T: int, L: int, written: torch.Tensor) -> BlockMask:
    """
    Create a block mask for flex_attention.

    Every query token may attend to every KV position where ``written`` is
    True (the mask does not depend on the query index), so the block map is
    simply the per-KV-block validity broadcast across all Q blocks.

    Args:
        T: Q length for this frame
        L: KV capacity == written.numel()
        written: [L] bool, True where there is valid KV data
    """
    BS = _DEFAULT_SPARSE_BLOCK_SIZE
    # Number of sparse blocks needed to cover the KV and Q axes (ceil div).
    KV_blocks = (L + BS - 1) // BS
    Q_blocks = (T + BS - 1) // BS

    # Pad `written` up to a whole number of blocks, then view one row per
    # KV block so per-block any/all reductions are cheap.
    written_blocks = torch.nn.functional.pad(written, (0, KV_blocks * BS - L)).view(
        KV_blocks, BS
    )

    # block_any: block contains at least one valid token (must be visited).
    # block_all: block is entirely valid (can skip the mask_mod fast-path).
    block_any = written_blocks.any(-1)
    block_all = written_blocks.all(-1)

    # Broadcast KV-block validity over all Q blocks; "partial" blocks are
    # those with some but not all tokens valid and need mask_mod applied.
    nonzero_bm = block_any[None, :].expand(Q_blocks, KV_blocks)
    full_bm = block_all[None, :].expand_as(nonzero_bm)
    partial_bm = nonzero_bm & ~full_bm

    def dense_to_ordered(dense_mask: torch.Tensor):
        # Convert a dense [Q_blocks, KV_blocks] bool map into the
        # (num_blocks, indices) layout BlockMask.from_kv_blocks expects,
        # adding the leading [B, H] singleton dims. The stable descending
        # argsort puts the True block indices first, in ascending order.

        num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32)
        indices = dense_mask.argsort(dim=-1, descending=True, stable=True).to(
            torch.int32
        )
        return num_blocks[None, None].contiguous(), indices[None, None].contiguous()

    kv_num_blocks, kv_indices = dense_to_ordered(partial_bm)

    full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm)

    def mask_mod(b, h, q, kv):
        # Token-level mask for partial blocks: visible iff the KV slot holds
        # valid data. Independent of batch, head, and query position.
        return written[kv]

    bm = BlockMask.from_kv_blocks(
        kv_num_blocks,
        kv_indices,
        full_kv_num_blocks,
        full_kv_indices,
        BLOCK_SIZE=BS,
        mask_mod=mask_mod,
        seq_lengths=(T, L),
        compute_q_blocks=False,
    )

    return bm
|
|
|
|
|
|
|
|
class LayerKVCache(nn.Module):
    """
    Ring-buffer KV cache with fixed capacity L (tokens) for history plus
    one extra frame (tokens_per_frame) at the tail holding the current frame.

    Layout along the token axis: [0, L) is the history ring (one frame per
    slot, addressed by a dilated frame-time bucket), [L, L + tokens_per_frame)
    is the tail slot that always holds the frame currently being processed.
    """

    def __init__(
        self, B, H, L, Dh, dtype, tokens_per_frame: int, pinned_dilation: int = 1
    ):
        # B: batch size, H: number of KV heads, L: history capacity in
        # tokens, Dh: head dim. `pinned_dilation` > 1 persists only every
        # pinned_dilation-th frame into the ring.
        super().__init__()
        self.tpf = tokens_per_frame
        self.L = L

        # History ring plus one frame-sized tail slot for the current frame.
        self.capacity = L + self.tpf
        self.pinned_dilation = pinned_dilation
        # Number of frame-sized ring slots actually addressable after
        # dilation (only every pinned_dilation-th frame is stored).
        self.num_buckets = (L // self.tpf) // self.pinned_dilation
        assert (L // self.tpf) % pinned_dilation == 0 and L % self.tpf == 0

        # kv[0] = keys, kv[1] = values: [2, B, H, capacity, Dh].
        self.kv = nn.Buffer(
            torch.zeros(2, B, H, self.capacity, Dh, dtype=dtype),
            persistent=False,
        )

        # Per-token validity mask over the capacity axis. The tail slot is
        # marked valid from the start: it is overwritten with the incoming
        # frame on every upsert before attention runs.
        written = torch.zeros(self.capacity, dtype=torch.bool)
        written[L:] = True
        self.written = nn.Buffer(written, persistent=False)

        # Token offsets within a single frame, and the absolute indices of
        # the tail (current-frame) slot; kept as buffers so they move with
        # the module across devices.
        self.frame_offsets = nn.Buffer(
            torch.arange(self.tpf, dtype=torch.long), persistent=False
        )
        self.current_idx = nn.Buffer(self.frame_offsets + L, persistent=False)

    def reset(self):
        # Clear all cached KV and validity, restoring the always-valid tail.
        self.kv.zero_()
        self.written.zero_()
        self.written[self.L :].fill_(True)

    def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool):
        """
        Insert one frame of KV and return the tensors plus attention mask.

        Args:
            kv: [2, B, H, T, Dh] for a single frame (T = tokens_per_frame)
            pos_ids: TensorDict with t_pos [B, T], all equal per frame (ignoring -1)
            is_frozen: when True the frame is only placed in the tail slot
                and NOT persisted into the history ring.

        Returns:
            (k, v, block_mask) where k/v are the full [B, H, capacity, Dh]
            cache views and block_mask is the flex_attention BlockMask.
        """
        T = self.tpf
        t_pos = pos_ids["t_pos"]

        # Eager-mode sanity checks only; skipped under torch.compile.
        if not torch.compiler.is_compiling():
            torch._check(
                kv.size(3) == self.tpf, "KV cache expects exactly one frame per upsert"
            )
            torch._check(t_pos.shape == (kv.size(1), T), "t_pos must be [B, T]")
            torch._check(self.tpf <= self.L, "frame longer than KV ring capacity")
            torch._check(
                self.L % self.tpf == 0,
                f"L ({self.L}) must be a multiple of tokens_per_frame ({self.tpf})",
            )
            torch._check(
                self.kv.size(3) == self.capacity,
                "KV buffer has unexpected length (expected L + tokens_per_frame)",
            )
            torch._check(
                (t_pos >= 0).all().item(),
                "t_pos must be non-negative during inference",
            )
            torch._check(
                ((t_pos == t_pos[:, :1]).all()).item(),
                "t_pos must be constant within frame",
            )

        # Frame time is constant within the frame (checked above); take one.
        frame_t = t_pos[0, 0]

        # Map frame time to a ring slot: ceil-divide by the dilation, then
        # wrap around the available buckets.
        bucket = (frame_t + (self.pinned_dilation - 1)) // self.pinned_dilation
        slot = bucket % self.num_buckets
        base = slot * T

        # Absolute token indices of this frame's ring slot.
        ring_idx = self.frame_offsets + base

        # Always place the incoming frame in the tail slot so attention for
        # this step can see it, even when the cache is frozen.
        self.kv.index_copy_(3, self.current_idx, kv)

        # Only frames aligned to the dilation are persisted into the ring.
        write_step = frame_t.remainder(self.pinned_dilation) == 0
        # Mask snapshot: when this is a persist step, hide the ring slot
        # being replaced so the current frame is attended via the tail slot
        # rather than a stale ring entry.
        mask_written = self.written.clone()
        mask_written[ring_idx] = mask_written[ring_idx] & ~write_step
        bm = make_block_mask(T, self.capacity, mask_written)

        if not is_frozen:
            # Persist into the ring on aligned steps; otherwise the target
            # collapses to the tail slot (a no-op rewrite, keeping the op
            # shape static for torch.compile).
            dst = torch.where(write_step, ring_idx, self.current_idx)
            self.kv.index_copy_(3, dst, kv)
            self.written[dst] = True

        k, v = self.kv.unbind(0)
        return k, v, bm
|
|
|
|
|
|
|
|
class StaticKVCache(nn.Module):
    """Static KV cache with per-layer configuration for local/global attention.

    Builds one LayerKVCache per transformer layer. Layers whose index falls
    on the configured global-attention period get the (larger) global window
    and the global pinned dilation; all other layers get the local window
    with no dilation.
    """

    def __init__(self, config, batch_size, dtype):
        super().__init__()

        self.tpf = config.tokens_per_frame

        # Window sizes in tokens, keyed by whether the layer is global.
        window_tokens = {
            True: config.global_window * self.tpf,
            False: config.local_window * self.tpf,
        }

        period = config.global_attn_period
        offset = getattr(config, "global_attn_offset", 0) % period
        kv_heads = getattr(config, "n_kv_heads", config.n_heads)
        head_dim = config.d_model // config.n_heads

        caches = []
        for idx in range(config.n_layers):
            # Every `period`-th layer, shifted by `offset`, attends globally.
            is_global = (idx - offset) % period == 0
            caches.append(
                LayerKVCache(
                    batch_size,
                    kv_heads,
                    window_tokens[is_global],
                    head_dim,
                    dtype,
                    self.tpf,
                    config.global_pinned_dilation if is_global else 1,
                )
            )
        self.layers = nn.ModuleList(caches)

        # While frozen, upserts serve reads but are not persisted.
        self._is_frozen = True

    def reset(self):
        """Clear every per-layer cache and freeze writes again."""
        for cache in self.layers:
            cache.reset()
        self._is_frozen = True

    def set_frozen(self, is_frozen: bool):
        """Toggle whether upserts persist into the per-layer ring buffers."""
        self._is_frozen = is_frozen

    def upsert(self, k: Tensor, v: Tensor, pos_ids: TensorDict, layer: int):
        """Stack k/v into one tensor and delegate to the addressed layer."""
        kv = torch.stack([k, v], dim=0)
        return self.layers[layer].upsert(kv, pos_ids, self._is_frozen)
|
|
|
|
|
|
|
|
class WorldEngineSetTimestepsStep(ModularPipelineBlocks):
    """Sets up the scheduler sigmas for rectified flow denoising.

    Resolves sigmas (user input overrides the pipeline config) into a tensor
    on the execution device, and normalizes ``frame_timestamp`` to a long
    tensor (defaulting to frame 0).
    """

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Sets up scheduler sigmas for rectified flow denoising"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [ConfigSpec("scheduler_sigmas", [1.0, 0.94921875, 0.83984375, 0.0])]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "scheduler_sigmas",
                type_hint=List[float],
                description="Custom scheduler sigmas (overrides config)",
            ),
            InputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "scheduler_sigmas",
                type_hint=torch.Tensor,
                description="Tensor of scheduler sigmas for denoising",
            ),
            OutputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Populate scheduler_sigmas and frame_timestamp in the block state."""
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        # Input sigmas take precedence; fall back to the pipeline config.
        sigmas = block_state.scheduler_sigmas
        if sigmas is None:
            sigmas = components.config.scheduler_sigmas
        # as_tensor avoids torch.tensor's copy-and-warn path when the caller
        # already supplies a tensor (e.g. sigmas carried over from a previous
        # frame); .to() then places/casts it uniformly.
        block_state.scheduler_sigmas = torch.as_tensor(sigmas).to(
            device=device, dtype=dtype
        )

        # Normalize frame_timestamp to a [1, 1] long tensor; None means the
        # first frame. A tensor input is passed through unchanged.
        frame_ts = block_state.frame_timestamp
        if frame_ts is None:
            frame_ts = torch.tensor([[0]], dtype=torch.long, device=device)
        elif isinstance(frame_ts, int):
            frame_ts = torch.tensor([[frame_ts]], dtype=torch.long, device=device)

        block_state.frame_timestamp = frame_ts

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
|
|
|
|
class WorldEngineSetupKVCacheStep(ModularPipelineBlocks):
    """Initializes or reuses the KV cache for autoregressive generation.

    A fresh StaticKVCache is built on the first call; on later calls the
    caller-supplied cache is reused, optionally reset via ``reset_cache``.
    """

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Initializes or reuses KV cache for autoregressive frame generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def inputs(self) -> List[InputParam]:
        existing_cache = InputParam(
            "kv_cache",
            type_hint=Optional[StaticKVCache],
            description="Existing KV cache (will be reused if provided)",
        )
        reset_flag = InputParam(
            "reset_cache",
            type_hint=bool,
            default=False,
            description="If True, reset the KV cache even if one exists",
        )
        return [existing_cache, reset_flag]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        cache_out = OutputParam(
            "kv_cache",
            type_hint=StaticKVCache,
            description="KV cache for transformer attention",
        )
        return [cache_out]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Ensure block_state.kv_cache holds a ready-to-use StaticKVCache."""
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        cache = block_state.kv_cache
        if cache is None:
            # First call: build a batch-1 cache shaped by the transformer
            # config and move it to the execution device.
            block_state.kv_cache = StaticKVCache(
                components.transformer.config,
                batch_size=1,
                dtype=dtype,
            ).to(device)
        elif block_state.reset_cache:
            cache.reset()

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
|
|
|
|
class WorldEnginePrepareLatentsStep(ModularPipelineBlocks):
    """Prepares latents for frame generation, optionally encoding an input image.

    When an image is provided it is preprocessed, VAE-encoded, and pushed
    through the transformer at sigma=0 so the frame is persisted as context
    in the KV cache; the frame timestamp is then advanced. Independently,
    the latents used for the actual denoising are (re)generated as random
    noise unless the caller disabled that and supplied latents.
    """

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return (
            "Prepares latents for frame generation. If an image is provided on the "
            "first frame, encodes it and caches it as context. Always creates fresh "
            "random noise for the actual denoising."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict(
                    {
                        "vae_scale_factor": 16,
                        "do_normalize": False,
                        "do_convert_rgb": False,
                    }
                ),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [
            ConfigSpec("channels", 16),
            ConfigSpec("height", 16),
            ConfigSpec("width", 16),
            ConfigSpec("patch", [2, 2]),
            ConfigSpec("vae_scale_factor", 16),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Union[PIL.Image.Image, torch.Tensor],
                description="Input image (PIL Image or [H, W, 3] uint8 tensor), only used on first frame",
            ),
            InputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]. Only used if use_random_latents=False.",
            ),
            InputParam(
                "use_random_latents",
                type_hint=bool,
                default=True,
                description="If True, always generate fresh random latents. If False, use provided latents.",
            ),
            InputParam(
                "kv_cache",
                description="KV cache to update",
            ),
            InputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="Prompt embeddings for cache pass",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Prompt padding mask",
            ),
            InputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                description="Button tensor for cache pass",
            ),
            InputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                description="Mouse tensor for cache pass",
            ),
            InputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                description="Scroll tensor for cache pass",
            ),
            InputParam(
                "generator",
                type_hint=torch.Generator,
                default=None,
                description="torch Generator for deterministic output",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]",
            ),
        ]

    @staticmethod
    def _cache_pass(
        transformer,
        x,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Cache pass to persist frame in KV cache.

        Runs the transformer with sigma=0 (clean frame) so its KV pairs are
        written into the unfrozen cache; the denoised output is discarded.
        NOTE(review): the cache is left unfrozen afterwards — presumably a
        later block restores the frozen state; confirm against the denoise
        step.
        """
        kv_cache.set_frozen(False)
        transformer(
            x=x,
            sigma=x.new_zeros((x.size(0), x.size(1))),
            frame_timestamp=frame_timestamp,
            prompt_emb=prompt_emb,
            prompt_pad_mask=prompt_pad_mask,
            mouse=mouse,
            button=button,
            scroll=scroll,
            kv_cache=kv_cache,
        )

    @torch.inference_mode()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Encode and cache the input image (if any), then prepare noise latents."""
        block_state = self.get_block_state(state)
        device = components._execution_device

        channels = components.config.channels
        height = components.config.height
        width = components.config.width
        patch = components.config.patch

        # Latent spatial size is vae_scale_factor * patch size per axis.
        # NOTE(review): `height`/`width` configs are only used for image
        # preprocessing, not the latent shape — confirm this is intended.
        pH, pW = patch if isinstance(patch, (list, tuple)) else (patch, patch)
        shape = (
            1,
            1,
            channels,
            components.config.vae_scale_factor * pH,
            components.config.vae_scale_factor * pW,
        )

        if block_state.image is not None:
            image = block_state.image

            # Resize/convert via the shared image processor ([1, 3, H, W] float).
            image = components.image_processor.preprocess(
                image,
                height=height,
                width=width,
            )

            # Back to [H, W, 3] uint8, the layout the VAE encoder expects here.
            image = (image[0].permute(1, 2, 0) * 255).to(torch.uint8)

            assert image.dtype == torch.uint8, (
                f"Expected uint8 image, got {image.dtype}"
            )

            latents = components.vae.encode(image)
            latents = latents.unsqueeze(1)  # add frame dim: [1, 1, C, H, W]

            # Persist the encoded frame as context in the KV cache, then
            # advance the timestamp so denoising targets the next frame.
            self._cache_pass(
                components.transformer,
                latents,
                block_state.frame_timestamp,
                block_state.prompt_embeds,
                block_state.prompt_pad_mask,
                block_state.mouse_tensor,
                block_state.button_tensor,
                block_state.scroll_tensor,
                block_state.kv_cache,
            )
            block_state.frame_timestamp.add_(1)

        if block_state.use_random_latents or block_state.latents is None:
            # Pass the user-supplied generator (may be None) so output is
            # reproducible; previously the `generator` input was declared
            # but silently ignored.
            block_state.latents = torch.randn(
                shape,
                generator=block_state.generator,
                device=device,
                dtype=torch.bfloat16,
            )

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
|
|
|
|
class WorldEngineBeforeDenoiseStep(SequentialPipelineBlocks):
    """Sequential pipeline that prepares all inputs for denoising.

    Composes the three preparation blocks in order: sigma setup, KV-cache
    setup, then latent preparation (which depends on the cache existing).
    """

    # Sub-blocks executed in list order; names give each block its key in
    # the composed pipeline.
    block_classes = [
        WorldEngineSetTimestepsStep,
        WorldEngineSetupKVCacheStep,
        WorldEnginePrepareLatentsStep,
    ]
    block_names = ["set_timesteps", "setup_kv_cache", "prepare_latents"]

    @property
    def description(self) -> str:
        return (
            "Before denoise step that prepares inputs for denoising:\n"
            " - WorldEngineSetTimestepsStep: Set up scheduler sigmas\n"
            " - WorldEngineSetupKVCacheStep: Initialize or reuse KV cache\n"
            " - WorldEnginePrepareLatentsStep: Encode image (if first frame) and create noise"
        )
|
|
|