Waypoint-1-Small / denoise.py
dn6's picture
dn6 HF Staff
Add diffusers support
57eef5f verified
raw
history blame
6.37 kB
# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Denoising block for WorldEngine modular pipeline."""
from typing import List, Tuple

import torch
from diffusers import AutoModel
from diffusers.modular_pipelines import (
    ModularPipelineBlocks,
    ModularPipeline,
    PipelineState,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
    ComponentSpec,
    InputParam,
    OutputParam,
)
from diffusers.utils import logging
# Module-level logger created through the diffusers logging utility.
logger = logging.get_logger(__name__)
class WorldEngineDenoiseLoop(ModularPipelineBlocks):
    """Denoises latents using rectified flow and updates KV cache.

    ``__call__`` runs two transformer passes for the current frame: a
    denoising pass with the KV cache frozen, followed by one pass at
    ``sigma == 0`` with the cache unfrozen. Afterwards ``frame_timestamp`` is
    incremented in place.
    """

    model_name = "world_engine"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        """Components this block reads from the pipeline (the transformer)."""
        return [ComponentSpec("transformer", AutoModel)]

    @property
    def description(self) -> str:
        """Human-readable summary of what the block does."""
        return (
            "Denoises latents using rectified flow (x = x + dsigma * v) "
            "and updates KV cache for autoregressive generation."
        )

    @property
    def inputs(self) -> List[InputParam]:
        """State entries read by ``__call__``.

        ``prompt_pad_mask`` is the only entry not marked as required.
        """
        return [
            InputParam(
                "scheduler_sigmas",
                required=True,
                type_hint=torch.Tensor,
                description="Scheduler sigmas for denoising",
            ),
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="Initial noisy latents [1, 1, C, H, W]",
            ),
            InputParam(
                "kv_cache",
                required=True,
                description="KV cache for transformer attention",
            ),
            InputParam(
                "frame_timestamp",
                required=True,
                type_hint=torch.Tensor,
                description="Current frame timestamp",
            ),
            InputParam(
                "prompt_embeds",
                required=True,
                type_hint=torch.Tensor,
                description="Text embeddings for conditioning",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Padding mask for prompt embeddings",
            ),
            InputParam(
                "button_tensor",
                required=True,
                type_hint=torch.Tensor,
                description="One-hot encoded button tensor",
            ),
            InputParam(
                "mouse_tensor",
                required=True,
                type_hint=torch.Tensor,
                description="Mouse velocity tensor",
            ),
            InputParam(
                "scroll_tensor",
                required=True,
                type_hint=torch.Tensor,
                description="Scroll wheel sign tensor",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        """State entries written by ``__call__`` (the denoised latents)."""
        return [
            OutputParam(
                "latents",
                type_hint=torch.Tensor,
                description="Denoised latents",
            ),
        ]

    @staticmethod
    def _denoise_pass(
        transformer,
        x,
        sigmas,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ) -> torch.Tensor:
        """Denoising loop using rectified flow.

        For each sigma paired with the difference to its successor, the
        transformer output ``v`` is applied as ``x = x + dsigma * v``.

        Args:
            transformer: Callable invoked with the keyword arguments
                ``x``, ``sigma``, ``frame_timestamp``, ``prompt_emb``,
                ``prompt_pad_mask``, ``mouse``, ``button``, ``scroll`` and
                ``kv_cache``.
            x: Latents to denoise; its first two dimensions size the sigma
                buffer passed to the transformer.
            sigmas: Sigma schedule (assumed 1-D -- TODO confirm).
            frame_timestamp: Forwarded unchanged to the transformer.
            prompt_emb: Forwarded unchanged to the transformer.
            prompt_pad_mask: Forwarded unchanged to the transformer.
            mouse: Forwarded unchanged to the transformer.
            button: Forwarded unchanged to the transformer.
            scroll: Forwarded unchanged to the transformer.
            kv_cache: Cache object; ``set_frozen(True)`` is called on it
                before the loop.

        Returns:
            The updated latents. When ``sigmas`` yields no differences the
            loop body never runs and the input ``x`` object itself is
            returned.
        """
        kv_cache.set_frozen(True)

        # One buffer of shape (x.size(0), x.size(1)) is allocated up front and
        # refilled in place on every step instead of being re-created.
        sigma = x.new_empty((x.size(0), x.size(1)))

        # ``sigmas.diff()`` has one element fewer than ``sigmas``, so ``zip``
        # stops before the final sigma value.
        for step_sig, step_dsig in zip(sigmas, sigmas.diff()):
            v = transformer(
                x=x,
                sigma=sigma.fill_(step_sig),
                frame_timestamp=frame_timestamp,
                prompt_emb=prompt_emb,
                prompt_pad_mask=prompt_pad_mask,
                mouse=mouse,
                button=button,
                scroll=scroll,
                kv_cache=kv_cache,
            )
            x = x + step_dsig * v
        return x

    @staticmethod
    def _cache_pass(
        transformer,
        x,
        frame_timestamp,
        prompt_emb,
        prompt_pad_mask,
        mouse,
        button,
        scroll,
        kv_cache,
    ) -> None:
        """Cache pass to persist frame for next generation.

        Calls ``kv_cache.set_frozen(False)`` and then runs the transformer
        once with an all-zero sigma of shape ``(x.size(0), x.size(1))``. The
        transformer's return value is discarded.

        Args:
            transformer: Callable invoked with the same keyword arguments as
                in ``_denoise_pass``.
            x: Latents passed to the transformer as ``x``.
            frame_timestamp: Forwarded unchanged to the transformer.
            prompt_emb: Forwarded unchanged to the transformer.
            prompt_pad_mask: Forwarded unchanged to the transformer.
            mouse: Forwarded unchanged to the transformer.
            button: Forwarded unchanged to the transformer.
            scroll: Forwarded unchanged to the transformer.
            kv_cache: Cache object that is unfrozen before the call.
        """
        kv_cache.set_frozen(False)
        transformer(
            x=x,
            sigma=x.new_zeros((x.size(0), x.size(1))),
            frame_timestamp=frame_timestamp,
            prompt_emb=prompt_emb,
            prompt_pad_mask=prompt_pad_mask,
            mouse=mouse,
            button=button,
            scroll=scroll,
            kv_cache=kv_cache,
        )

    @torch.inference_mode()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """Denoise the current latents, refresh the KV cache, advance the frame.

        Args:
            components: Pipeline object providing ``transformer``.
            state: Pipeline state holding the entries listed in ``inputs``.

        Returns:
            The ``(components, state)`` pair, with ``latents`` in ``state``
            replaced by the denoised result.
        """
        block_state = self.get_block_state(state)

        # The result is cloned before being stored.
        # NOTE(review): presumably so the stored latents do not alias a buffer
        # that is reused later (or the input, when no step runs) -- confirm.
        block_state.latents = self._denoise_pass(
            components.transformer,
            block_state.latents,
            block_state.scheduler_sigmas,
            block_state.frame_timestamp,
            block_state.prompt_embeds,
            block_state.prompt_pad_mask,
            block_state.mouse_tensor,
            block_state.button_tensor,
            block_state.scroll_tensor,
            block_state.kv_cache,
        ).clone()

        # Second pass over the denoised latents with the cache unfrozen.
        self._cache_pass(
            components.transformer,
            block_state.latents,
            block_state.frame_timestamp,
            block_state.prompt_embeds,
            block_state.prompt_pad_mask,
            block_state.mouse_tensor,
            block_state.button_tensor,
            block_state.scroll_tensor,
            block_state.kv_cache,
        )

        # In-place increment: the timestamp tensor object held in the state is
        # mutated rather than replaced.
        block_state.frame_timestamp.add_(1)

        self.set_block_state(state, block_state)
        return components, state