# Copyright (C) 2025 Hugging Face Team and Overworld
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Text and controller encoder blocks for WorldEngine modular pipeline."""
import html
from typing import List, Set, Tuple, Union
import regex as re
import torch
from transformers import AutoTokenizer, UMT5EncoderModel
from diffusers.utils import is_ftfy_available, logging
from diffusers.modular_pipelines import (
ModularPipelineBlocks,
ModularPipeline,
PipelineState,
)
from diffusers.modular_pipelines.modular_pipeline_utils import (
ComponentSpec,
ConfigSpec,
InputParam,
OutputParam,
)
if is_ftfy_available():
import ftfy
logger = logging.get_logger(__name__)
def basic_clean(text):
    """Fix broken text encoding (when ftfy is installed) and unescape HTML.

    Args:
        text: Raw prompt string.

    Returns:
        The cleaned string with leading/trailing whitespace removed.
    """
    # ftfy is optional: the module only imports it when available, so calling
    # ``ftfy.fix_text`` unconditionally raised NameError without it. Resolve it
    # locally and skip the mojibake repair when the package is missing.
    try:
        import ftfy
    except ImportError:
        ftfy = None
    if ftfy is not None:
        text = ftfy.fix_text(text)
    # Unescape twice so doubly-escaped entities (e.g. "&amp;amp;") become "&".
    text = html.unescape(html.unescape(text))
    return text.strip()
def whitespace_clean(text):
    """Collapse each run of whitespace into one space and trim both ends."""
    return re.sub(r"\s+", " ", text).strip()
def prompt_clean(text):
    """Normalize a prompt: run ``basic_clean`` first, then ``whitespace_clean``."""
    return whitespace_clean(basic_clean(text))
class WorldEngineTextEncoderStep(ModularPipelineBlocks):
    """Encodes text prompts using UMT5-XL for conditioning.

    Produces ``prompt_embeds`` and ``prompt_pad_mask`` for the denoiser. If the
    caller already supplies ``prompt_embeds``, the text encoder is skipped and
    only a missing padding mask is synthesized.
    """

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return (
            "Text Encoder step that generates text embeddings to guide frame generation"
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", UMT5EncoderModel),
            ComponentSpec("tokenizer", AutoTokenizer),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "prompt",
                description="The prompt or prompts to guide the frame generation",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="Pre-computed text embeddings",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Padding mask for prompt embeddings",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Text embeddings used to guide frame generation",
            ),
            OutputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Padding mask for prompt embeddings",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        """Validate ``block_state.prompt``.

        Raises:
            ValueError: if ``prompt`` is not ``None``, a ``str``, or a list whose
                items are all ``str``.
        """
        prompt = block_state.prompt
        if prompt is None or isinstance(prompt, str):
            return
        if not isinstance(prompt, list):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )
        # Reject non-string items up front instead of letting prompt_clean /
        # the tokenizer fail later with an unrelated-looking error.
        bad_types = [type(p) for p in prompt if not isinstance(p, str)]
        if bad_types:
            raise ValueError(
                f"`prompt` list items have to be of type `str` but found {bad_types[0]}"
            )

    @staticmethod
    def encode_prompt(
        components,
        prompt: Union[str, List[str]],
        device: torch.device,
        max_sequence_length: int = 512,
    ):
        """Tokenize and encode ``prompt`` with the UMT5 text encoder.

        Args:
            components: Pipeline components providing ``tokenizer`` and ``text_encoder``.
            prompt: A single prompt or a list of prompts.
            device: Device the token ids and attention mask are moved to.
            max_sequence_length: Every prompt is truncated/padded to this length.

        Returns:
            ``(prompt_embeds, prompt_pad_mask)``. ``prompt_embeds`` is the encoder's
            last hidden state cast to the encoder dtype, with padded positions
            zeroed. ``prompt_pad_mask`` is a bool tensor that is True where padded.
        """
        dtype = components.text_encoder.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [prompt_clean(p) for p in prompt]
        text_inputs = components.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(device)
        attention_mask = text_inputs.attention_mask.to(device)
        prompt_embeds = components.text_encoder(
            text_input_ids, attention_mask
        ).last_hidden_state
        prompt_embeds = prompt_embeds.to(dtype=dtype)
        # Zero out padding so padded positions carry no signal.
        prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).type_as(
            prompt_embeds
        )
        # Create padding mask (True where padded).
        prompt_pad_mask = attention_mask.eq(0)
        return prompt_embeds, prompt_pad_mask

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """Populate ``prompt_embeds`` / ``prompt_pad_mask`` in the pipeline state.

        Returns:
            ``(components, state)``, following the modular-pipeline block convention.
        """
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)
        device = components._execution_device
        if block_state.prompt_embeds is None:
            # An empty/missing prompt falls back to a generic scene description.
            block_state.prompt = block_state.prompt or "An explorable world"
            (
                block_state.prompt_embeds,
                block_state.prompt_pad_mask,
            ) = self.encode_prompt(components, block_state.prompt, device)
        block_state.prompt_embeds = block_state.prompt_embeds.contiguous()
        if block_state.prompt_pad_mask is None:
            # Caller gave embeddings without a mask: treat every position as real.
            block_state.prompt_pad_mask = torch.zeros(
                block_state.prompt_embeds.shape[:2],
                dtype=torch.bool,
                device=device,
            )
        self.set_block_state(state, block_state)
        return components, state
class WorldEngineControllerEncoderStep(ModularPipelineBlocks):
    """Encodes controller inputs (mouse + buttons + scroll) for conditioning.

    The three output tensors are allocated once (batch 1, one frame) and then
    overwritten in place on later calls, so their shapes never change between
    frames. A caller-supplied tensor is mutated in place as well.
    """

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Controller Encoder step that encodes mouse, button, and scroll inputs for conditioning"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []  # Controller embedding is part of transformer

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [ConfigSpec("n_buttons", 256)]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "button",
                type_hint=Set[int],
                default=set(),
                description="Set of pressed button IDs",
            ),
            InputParam(
                "mouse",
                type_hint=Tuple[float, float],
                default=(0.0, 0.0),
                description="Mouse velocity (x, y)",
            ),
            InputParam(
                "scroll",
                type_hint=int,
                default=0,
                description="Scroll wheel direction (-1, 0, 1)",
            ),
            InputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="One-hot encoded button tensor",
            ),
            InputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Mouse velocity tensor",
            ),
            InputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Scroll wheel sign tensor",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="One-hot encoded button tensor",
            ),
            OutputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Mouse velocity tensor",
            ),
            OutputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Scroll wheel sign tensor",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """Write the current controller state into the persistent tensors.

        Returns:
            ``(components, state)`` with ``button_tensor`` [1, 1, n_buttons],
            ``mouse_tensor`` [1, 1, 2] and ``scroll_tensor`` [1, 1, 1] updated.
        """
        block_state = self.get_block_state(state)
        device = components._execution_device
        # NOTE(review): reads the transformer's dtype although expected_components
        # is empty; confirm the transformer is always loaded when this block runs.
        dtype = components.transformer.dtype
        n_buttons = components.config.n_buttons

        # Create or reuse button tensor [1, 1, n_buttons].
        if block_state.button_tensor is None:
            block_state.button_tensor = torch.zeros(
                (1, 1, n_buttons), device=device, dtype=dtype
            )
        # Update button tensor in-place (avoid dynamic shapes for torch.compile).
        block_state.button_tensor.zero_()
        if block_state.button:
            for btn_id in block_state.button:
                # Multi-hot: every pressed id in range is set to 1; ids outside
                # [0, n_buttons) are ignored.
                if 0 <= btn_id < n_buttons:
                    block_state.button_tensor[0, 0, btn_id] = 1.0

        # Create or reuse mouse tensor [1, 1, 2].
        if block_state.mouse_tensor is None:
            block_state.mouse_tensor = torch.zeros(
                (1, 1, 2), device=device, dtype=dtype
            )
        # Update mouse tensor in-place.
        mouse = block_state.mouse if block_state.mouse is not None else (0.0, 0.0)
        block_state.mouse_tensor[0, 0, 0] = mouse[0]
        block_state.mouse_tensor[0, 0, 1] = mouse[1]

        # Create or reuse scroll tensor [1, 1, 1].
        if block_state.scroll_tensor is None:
            block_state.scroll_tensor = torch.zeros(
                (1, 1, 1), device=device, dtype=dtype
            )
        # Store only the sign of the scroll value: -1, 0, or 1.
        scroll = block_state.scroll if block_state.scroll is not None else 0
        block_state.scroll_tensor[0, 0, 0] = float(scroll > 0) - float(scroll < 0)

        self.set_block_state(state, block_state)
        return components, state
|