# Waypoint-1-Small / before_denoise.py
# Hugging Face file-viewer metadata (not part of the module):
#   author: dn6 (HF Staff)
#   commit: 57eef5f (verified) - "Add diffusers support"
#   size: 20.7 kB
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Before-denoise blocks for WorldEngine modular pipeline."""
from typing import List, Optional, Union
import PIL.Image
import torch
from torch import nn, Tensor
from tensordict import TensorDict
from torch.nn.attention.flex_attention import _DEFAULT_SPARSE_BLOCK_SIZE, BlockMask
from diffusers.configuration_utils import FrozenDict
from diffusers.image_processor import VaeImageProcessor
from diffusers.utils import logging
from diffusers.utils.torch_utils import randn_tensor
from diffusers.modular_pipelines import (
ModularPipelineBlocks,
ModularPipeline,
PipelineState,
SequentialPipelineBlocks,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
InputParam,
OutputParam,
)
# Module-level logger, obtained through diffusers' logging helper and keyed by this module's name.
logger = logging.get_logger(__name__)
def make_block_mask(T: int, L: int, written: torch.Tensor) -> BlockMask:
"""
Create a block mask for flex_attention.
Args:
T: Q length for this frame
L: KV capacity == written.numel()
written: [L] bool, True where there is valid KV data
"""
BS = _DEFAULT_SPARSE_BLOCK_SIZE
KV_blocks = (L + BS - 1) // BS
Q_blocks = (T + BS - 1) // BS
# [KV_blocks, BS]
written_blocks = torch.nn.functional.pad(written, (0, KV_blocks * BS - L)).view(
KV_blocks, BS
)
# Block-level occupancy
block_any = written_blocks.any(-1) # block has at least one written token
block_all = written_blocks.all(-1) # block is fully written
# Every Q-block sees the same KV-block pattern
nonzero_bm = block_any[None, :].expand(Q_blocks, KV_blocks) # [Q_blocks, KV_blocks]
full_bm = block_all[None, :].expand_as(nonzero_bm) # [Q_blocks, KV_blocks]
partial_bm = nonzero_bm & ~full_bm # [Q_blocks, KV_blocks]
def dense_to_ordered(dense_mask: torch.Tensor):
# dense_mask: [Q_blocks, KV_blocks] bool
# returns: [1,1,Q_blocks], [1,1,Q_blocks,KV_blocks]
num_blocks = dense_mask.sum(dim=-1, dtype=torch.int32) # [Q_blocks]
indices = dense_mask.argsort(dim=-1, descending=True, stable=True).to(
torch.int32
)
return num_blocks[None, None].contiguous(), indices[None, None].contiguous()
# Partial blocks (need mask_mod)
kv_num_blocks, kv_indices = dense_to_ordered(partial_bm)
# Full blocks (mask_mod can be skipped entirely)
full_kv_num_blocks, full_kv_indices = dense_to_ordered(full_bm)
def mask_mod(b, h, q, kv):
return written[kv]
bm = BlockMask.from_kv_blocks(
kv_num_blocks,
kv_indices,
full_kv_num_blocks,
full_kv_indices,
BLOCK_SIZE=BS,
mask_mod=mask_mod,
seq_lengths=(T, L),
compute_q_blocks=False, # no backward, avoids the transpose/_ordered_to_dense path
)
return bm
class LayerKVCache(nn.Module):
"""
Ring-buffer KV cache with fixed capacity L (tokens) for history plus
one extra frame (tokens_per_frame) at the tail holding the current frame.
"""
def __init__(
self, B, H, L, Dh, dtype, tokens_per_frame: int, pinned_dilation: int = 1
):
super().__init__()
self.tpf = tokens_per_frame
self.L = L
# total KV capacity: ring (L) + tail frame (tpf)
self.capacity = L + self.tpf
self.pinned_dilation = pinned_dilation
self.num_buckets = (L // self.tpf) // self.pinned_dilation
assert (L // self.tpf) % pinned_dilation == 0 and L % self.tpf == 0
# KV buffer: [2, B, H, capacity, Dh]
self.kv = nn.Buffer(
torch.zeros(2, B, H, self.capacity, Dh, dtype=dtype),
persistent=False,
)
# which slots have ever been written
# tail slice [L, L+tpf) always holds the current frame and is considered written
written = torch.zeros(self.capacity, dtype=torch.bool)
written[L:] = True
self.written = nn.Buffer(written, persistent=False)
# Precompute indices:
# frame_offsets: [0, 1, ..., tpf-1] (for ring indexing)
# current_idx: [L, L+1, ..., L+tpf-1] (tail slice)
self.frame_offsets = nn.Buffer(
torch.arange(self.tpf, dtype=torch.long), persistent=False
)
self.current_idx = nn.Buffer(self.frame_offsets + L, persistent=False)
def reset(self):
self.kv.zero_()
self.written.zero_()
self.written[self.L :].fill_(True)
def upsert(self, kv: Tensor, pos_ids: TensorDict, is_frozen: bool):
"""
Args:
kv: [2, B, H, T, Dh] for a single frame (T = tokens_per_frame)
pos_ids: TensorDict with t_pos [B, T], all equal per frame (ignoring -1)
"""
T = self.tpf
t_pos = pos_ids["t_pos"]
if not torch.compiler.is_compiling():
torch._check(
kv.size(3) == self.tpf, "KV cache expects exactly one frame per upsert"
)
torch._check(t_pos.shape == (kv.size(1), T), "t_pos must be [B, T]")
torch._check(self.tpf <= self.L, "frame longer than KV ring capacity")
torch._check(
self.L % self.tpf == 0,
f"L ({self.L}) must be a multiple of tokens_per_frame ({self.tpf})",
)
torch._check(
self.kv.size(3) == self.capacity,
"KV buffer has unexpected length (expected L + tokens_per_frame)",
)
torch._check(
(t_pos >= 0).all().item(),
"t_pos must be non-negative during inference",
)
torch._check(
((t_pos == t_pos[:, :1]).all()).item(),
"t_pos must be constant within frame",
)
frame_t = t_pos[0, 0]
# map frame_t to a bucket, each bucket owns T contiguous slots
bucket = (frame_t + (self.pinned_dilation - 1)) // self.pinned_dilation
slot = bucket % self.num_buckets
base = slot * T
# indices in the ring for this frame: [T] in [0, L)
ring_idx = self.frame_offsets + base
# Always write current frame into the tail slice [L, L+T):
# this is the "self-attention component" for the current frame.
self.kv.index_copy_(3, self.current_idx, kv)
write_step = frame_t.remainder(self.pinned_dilation) == 0
mask_written = self.written.clone()
mask_written[ring_idx] = mask_written[ring_idx] & ~write_step
bm = make_block_mask(T, self.capacity, mask_written)
# Persist current frame into the ring for future queries when unfrozen.
if not is_frozen:
# Persist current frame into the ring for future queries.
dst = torch.where(write_step, ring_idx, self.current_idx)
self.kv.index_copy_(3, dst, kv)
self.written[dst] = True
k, v = self.kv.unbind(0)
return k, v, bm
class StaticKVCache(nn.Module):
    """Static KV cache with per-layer configuration for local/global attention.

    Holds one ``LayerKVCache`` per transformer layer. A layer is "global" when
    ``(layer_idx - offset) % global_attn_period == 0``; global layers use the
    global window and ``global_pinned_dilation``, all others use the local
    window with a dilation of 1.
    """

    def __init__(self, config, batch_size, dtype):
        """
        Args:
            config: object exposing ``tokens_per_frame``, ``local_window``,
                ``global_window``, ``global_attn_period``, ``n_layers``,
                ``n_heads``, ``d_model``, ``global_pinned_dilation`` and
                optionally ``global_attn_offset`` / ``n_kv_heads``.
            batch_size: batch size handed to every layer cache.
            dtype: dtype handed to every layer cache.
        """
        super().__init__()
        self.tpf = config.tokens_per_frame
        # Window sizes are configured in frames; the layer caches work in tokens.
        local_L = config.local_window * self.tpf
        global_L = config.global_window * self.tpf
        period = config.global_attn_period
        off = getattr(config, "global_attn_offset", 0) % period

        per_layer = []
        for layer_idx in range(config.n_layers):
            uses_global_attn = (layer_idx - off) % period == 0
            per_layer.append(
                LayerKVCache(
                    batch_size,
                    # Fall back to n_heads when the config has no separate KV head count.
                    getattr(config, "n_kv_heads", config.n_heads),
                    global_L if uses_global_attn else local_L,
                    config.d_model // config.n_heads,
                    dtype,
                    self.tpf,
                    config.global_pinned_dilation if uses_global_attn else 1,
                )
            )
        self.layers = nn.ModuleList(per_layer)

        # A new cache starts frozen.
        self._is_frozen = True

    def reset(self):
        """Reset every layer cache and return to the frozen state."""
        for layer_cache in self.layers:
            layer_cache.reset()
        self._is_frozen = True

    def set_frozen(self, is_frozen: bool):
        """Set the frozen flag that is forwarded to the layer caches on ``upsert``."""
        self._is_frozen = is_frozen

    def upsert(self, k: Tensor, v: Tensor, pos_ids: TensorDict, layer: int):
        """Stack ``k`` and ``v`` on a new leading dim and upsert them into layer ``layer``.

        Returns whatever the selected layer cache's ``upsert`` returns.
        """
        stacked_kv = torch.stack([k, v], dim=0)
        target = self.layers[layer]
        return target.upsert(stacked_kv, pos_ids, self._is_frozen)
class WorldEngineSetTimestepsStep(ModularPipelineBlocks):
    """Pipeline block that materializes the scheduler sigmas and the frame timestamp as tensors."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Sets up scheduler sigmas for rectified flow denoising"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # This block declares no components of its own.
        return []

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        default_sigmas = [1.0, 0.94921875, 0.83984375, 0.0]
        return [ConfigSpec("scheduler_sigmas", default_sigmas)]

    @property
    def inputs(self) -> List[InputParam]:
        sigmas_param = InputParam(
            "scheduler_sigmas",
            type_hint=List[float],
            description="Custom scheduler sigmas (overrides config)",
        )
        timestamp_param = InputParam(
            "frame_timestamp",
            type_hint=torch.Tensor,
            description="Current frame timestamp",
        )
        return [sigmas_param, timestamp_param]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        sigmas_out = OutputParam(
            "scheduler_sigmas",
            type_hint=torch.Tensor,
            description="Tensor of scheduler sigmas for denoising",
        )
        timestamp_out = OutputParam(
            "frame_timestamp",
            type_hint=torch.Tensor,
            description="Current frame timestamp",
        )
        return [sigmas_out, timestamp_out]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Convert sigmas and frame timestamp to tensors on the execution device.

        Returns the ``(components, state)`` pair after writing the block state back.
        """
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        # Caller-supplied sigmas win; otherwise fall back to the pipeline config.
        sigma_values = block_state.scheduler_sigmas
        if sigma_values is None:
            sigma_values = components.config.scheduler_sigmas
        block_state.scheduler_sigmas = torch.tensor(
            sigma_values, device=device, dtype=dtype
        )

        # A missing timestamp means frame 0; a plain int is wrapped as a [1, 1]
        # long tensor. Any other value is passed through untouched.
        timestamp = block_state.frame_timestamp
        if timestamp is None or isinstance(timestamp, int):
            timestamp = torch.tensor(
                [[0 if timestamp is None else timestamp]],
                dtype=torch.long,
                device=device,
            )
        block_state.frame_timestamp = timestamp

        self.set_block_state(state, block_state)
        return components, state
class WorldEngineSetupKVCacheStep(ModularPipelineBlocks):
    """Pipeline block that creates a KV cache on first use and optionally resets an existing one."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Initializes or reuses KV cache for autoregressive frame generation"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        # This block declares no components of its own.
        return []

    @property
    def inputs(self) -> List[InputParam]:
        cache_param = InputParam(
            "kv_cache",
            type_hint=Optional[StaticKVCache],
            description="Existing KV cache (will be reused if provided)",
        )
        reset_param = InputParam(
            "reset_cache",
            type_hint=bool,
            default=False,
            description="If True, reset the KV cache even if one exists",
        )
        return [cache_param, reset_param]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        cache_out = OutputParam(
            "kv_cache",
            type_hint=StaticKVCache,
            description="KV cache for transformer attention",
        )
        return [cache_out]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Ensure ``block_state.kv_cache`` holds a usable cache.

        Returns the ``(components, state)`` pair after writing the block state back.
        """
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        existing_cache = block_state.kv_cache
        if existing_cache is None:
            # No cache yet: build one (batch size 1) from the transformer config
            # and move it to the execution device.
            new_cache = StaticKVCache(
                components.transformer.config,
                batch_size=1,
                dtype=dtype,
            )
            block_state.kv_cache = new_cache.to(device)
        elif block_state.reset_cache:
            # A cache was supplied and the caller asked for a clean slate.
            existing_cache.reset()

        self.set_block_state(state, block_state)
        return components, state
class WorldEnginePrepareLatentsStep(ModularPipelineBlocks):
    """Prepares latents for frame generation, optionally encoding an input image.

    When an image is supplied it is preprocessed, encoded with the VAE and
    pushed through the transformer once so the KV cache retains it; the frame
    timestamp is then advanced by one. Afterwards the latents to denoise are
    either sampled as fresh noise or taken from the ``latents`` input.
    """

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return (
            "Prepares latents for frame generation. If an image is provided on the "
            "first frame, encodes it and caches it as context. Always creates fresh "
            "random noise for the actual denoising."
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec(
                "image_processor",
                VaeImageProcessor,
                config=FrozenDict(
                    {
                        "vae_scale_factor": 16,
                        "do_normalize": False,
                        "do_convert_rgb": False,
                    }
                ),
                default_creation_method="from_config",
            ),
        ]

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [
            ConfigSpec("channels", 16),
            ConfigSpec("height", 16),
            ConfigSpec("width", 16),
            ConfigSpec("patch", [2, 2]),
            ConfigSpec("vae_scale_factor", 16),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "image",
                type_hint=Union[PIL.Image.Image, torch.Tensor],
                description="Input image (PIL Image or [H, W, 3] uint8 tensor), only used on first frame",
            ),
            InputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]. Only used if use_random_latents=False.",
            ),
            InputParam(
                "use_random_latents",
                type_hint=bool,
                default=True,
                description="If True, always generate fresh random latents. If False, use provided latents.",
            ),
            InputParam(
                "kv_cache",
                description="KV cache to update",
            ),
            InputParam(
                "frame_timestamp",
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="Prompt embeddings for cache pass",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Prompt padding mask",
            ),
            InputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                description="Button tensor for cache pass",
            ),
            InputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                description="Mouse tensor for cache pass",
            ),
            InputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                description="Scroll tensor for cache pass",
            ),
            InputParam(
                "generator",
                type_hint=torch.Generator,
                default=None,
                description="torch Generator for deterministic output",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Latent tensor for denoising [1, 1, C, H, W]",
            ),
        ]

    @staticmethod
    def _cache_pass(
        transformer,
        x,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ):
        """Cache pass to persist frame in KV cache.

        Unfreezes ``kv_cache`` and runs the transformer once on ``x`` with an
        all-zero sigma; the transformer output is discarded. The cache is left
        unfrozen when this returns.
        """
        kv_cache.set_frozen(False)
        transformer(
            x=x,
            # Zero sigma for every (batch, frame) entry of x.
            sigma=x.new_zeros((x.size(0), x.size(1))),
            frame_timestamp=frame_timestamp,
            prompt_emb=prompt_emb,
            prompt_pad_mask=prompt_pad_mask,
            mouse=mouse,
            button=button,
            scroll=scroll,
            kv_cache=kv_cache,
        )

    @torch.inference_mode()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Encode the optional context image and prepare the latents to denoise.

        Returns the ``(components, state)`` pair after writing the block state back.
        """
        block_state = self.get_block_state(state)
        device = components._execution_device

        # Get latent shape info
        channels = components.config.channels
        height = components.config.height
        width = components.config.width
        patch = components.config.patch
        pH, pW = patch if isinstance(patch, (list, tuple)) else (patch, patch)

        shape = (
            1,
            1,
            channels,
            components.config.vae_scale_factor * pH,
            components.config.vae_scale_factor * pW,
        )

        if block_state.image is not None:
            # Preprocess: PIL/tensor -> [B, C, H, W] float32 in [0, 1]
            image = components.image_processor.preprocess(
                block_state.image,
                height=height,
                width=width,
            )

            # Convert to [H, W, 3] uint8 for VAE encoder. Round and clamp before
            # the cast: a bare float -> uint8 cast truncates, so a value such as
            # 127.99999 (float32 round-trip error) would lose one intensity level.
            image = (image[0].permute(1, 2, 0) * 255).round().clamp(0, 255)
            image = image.to(torch.uint8)

            latents = components.vae.encode(image)
            latents = latents.unsqueeze(1)

            # Run cache pass to persist encoded frame
            self._cache_pass(
                components.transformer,
                latents,
                block_state.frame_timestamp,
                block_state.prompt_embeds,
                block_state.prompt_pad_mask,
                block_state.mouse_tensor,
                block_state.button_tensor,
                block_state.scroll_tensor,
                block_state.kv_cache,
            )
            # The context image occupies one frame; advance the timestamp in place.
            block_state.frame_timestamp.add_(1)

        # Generate latents based on use_random_latents flag
        if block_state.use_random_latents or block_state.latents is None:
            # Draw through randn_tensor so the `generator` input is honoured and
            # seeded runs are reproducible.
            # NOTE(review): the noise dtype is fixed to bfloat16 rather than
            # following the transformer's dtype - confirm this is intended.
            block_state.latents = randn_tensor(
                shape,
                generator=block_state.generator,
                device=device,
                dtype=torch.bfloat16,
            )

        self.set_block_state(state, block_state)
        return components, state
class WorldEngineBeforeDenoiseStep(SequentialPipelineBlocks):
    """Runs the three before-denoise blocks in order: timesteps, KV cache, latents."""

    # block_classes and block_names are parallel lists: one name per block, in order.
    block_classes = [
        WorldEngineSetTimestepsStep,
        WorldEngineSetupKVCacheStep,
        WorldEnginePrepareLatentsStep,
    ]
    block_names = [
        "set_timesteps",
        "setup_kv_cache",
        "prepare_latents",
    ]

    @property
    def description(self) -> str:
        summary = (
            "Before denoise step that prepares inputs for denoising:\n"
            " - WorldEngineSetTimestepsStep: Set up scheduler sigmas\n"
            " - WorldEngineSetupKVCacheStep: Initialize or reuse KV cache\n"
            " - WorldEnginePrepareLatentsStep: Encode image (if first frame) and create noise"
        )
        return summary