|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Text and controller encoder blocks for WorldEngine modular pipeline.""" |
|
|
|
|
|
import html |
|
|
from typing import List, Set, Tuple, Union |
|
|
|
|
|
import regex as re |
|
|
import torch |
|
|
from transformers import AutoTokenizer, UMT5EncoderModel |
|
|
|
|
|
from diffusers.utils import is_ftfy_available, logging |
|
|
from diffusers.modular_pipelines import ( |
|
|
ModularPipelineBlocks, |
|
|
ModularPipeline, |
|
|
PipelineState, |
|
|
) |
|
|
from diffusers.modular_pipelines.modular_pipeline_utils import ( |
|
|
ComponentSpec, |
|
|
ConfigSpec, |
|
|
InputParam, |
|
|
OutputParam, |
|
|
) |
|
|
|
|
|
if is_ftfy_available(): |
|
|
import ftfy |
|
|
|
|
|
|
|
|
logger = logging.get_logger(__name__)  # module-level logger from diffusers' logging utilities
|
|
|
|
|
|
|
|
def basic_clean(text):
    """Repair and unescape raw prompt text.

    Runs ``ftfy.fix_text`` when ftfy is installed, then HTML-unescapes the text
    twice (so double-escaped entities such as ``&amp;amp;`` decode fully) and
    strips surrounding whitespace.

    ftfy is only imported by this module when ``is_ftfy_available()`` is true,
    so the ftfy pass is skipped when it is missing instead of failing with a
    ``NameError``. Output may then differ slightly from the ftfy-cleaned text.
    """
    if is_ftfy_available():
        text = ftfy.fix_text(text)
    # Unescape twice to undo entities that were escaped twice upstream.
    text = html.unescape(html.unescape(text))
    return text.strip()
|
|
|
|
|
|
|
|
def whitespace_clean(text):
    """Collapse every whitespace run in ``text`` to one space and trim both ends."""
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()
|
|
|
|
|
|
|
|
def prompt_clean(text):
    """Normalize a raw prompt: repair/unescape it, then tidy its whitespace."""
    return whitespace_clean(basic_clean(text))
|
|
|
|
|
|
|
|
class WorldEngineTextEncoderStep(ModularPipelineBlocks):
    """Encodes text prompts using UMT5-XL for conditioning.

    Reads ``prompt`` (or pre-computed ``prompt_embeds`` / ``prompt_pad_mask``)
    from the pipeline state and writes ``prompt_embeds`` and ``prompt_pad_mask``
    back as ``denoiser_input_fields``.
    """

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return (
            "Text Encoder step that generates text embeddings to guide frame generation"
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", UMT5EncoderModel),
            ComponentSpec("tokenizer", AutoTokenizer),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "prompt",
                description="The prompt or prompts to guide the frame generation",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="Pre-computed text embeddings",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Padding mask for prompt embeddings",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Text embeddings used to guide frame generation",
            ),
            OutputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Padding mask for prompt embeddings",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        """Validate the ``prompt`` input.

        Raises:
            ValueError: if ``prompt`` is not ``None``, a ``str``, or a ``list``
                whose elements are all ``str``.
        """
        prompt = block_state.prompt
        if prompt is None:
            return
        if not isinstance(prompt, (str, list)):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(prompt)}"
            )
        # Catch non-string list elements here rather than deep inside the
        # cleaning/tokenizer code where the error would be obscure.
        if isinstance(prompt, list) and not all(isinstance(p, str) for p in prompt):
            raise ValueError("`prompt` list must contain only `str` elements")

    @staticmethod
    def encode_prompt(
        components,
        prompt: Union[str, List[str]],
        device: torch.device,
        max_sequence_length: int = 512,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize and encode ``prompt`` with the text encoder.

        Args:
            components: Pipeline components providing ``tokenizer`` and
                ``text_encoder``.
            prompt: A single prompt or a list of prompts. Each is cleaned with
                ``prompt_clean`` before tokenization.
            device: Device the token ids and attention mask are moved to.
            max_sequence_length: Every prompt is padded / truncated to exactly
                this many tokens.

        Returns:
            ``(prompt_embeds, prompt_pad_mask)``. ``prompt_embeds`` is the
            encoder's ``last_hidden_state`` cast to the encoder dtype, with
            padded positions zeroed. ``prompt_pad_mask`` is a boolean tensor
            that is ``True`` at padded positions.
        """
        dtype = components.text_encoder.dtype

        prompts = [prompt] if isinstance(prompt, str) else prompt
        prompts = [prompt_clean(p) for p in prompts]

        text_inputs = components.tokenizer(
            prompts,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(device)
        attention_mask = text_inputs.attention_mask.to(device)

        prompt_embeds = components.text_encoder(
            text_input_ids, attention_mask=attention_mask
        ).last_hidden_state
        prompt_embeds = prompt_embeds.to(dtype=dtype)

        # Zero the hidden states at padded positions so they carry no signal.
        prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).type_as(
            prompt_embeds
        )

        # The tokenizer mask is 1 for real tokens; invert it so True marks padding.
        prompt_pad_mask = attention_mask.eq(0)

        return prompt_embeds, prompt_pad_mask

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """Populate ``prompt_embeds`` / ``prompt_pad_mask`` in ``state``.

        If ``prompt_embeds`` is absent, the prompt is encoded. A falsy prompt
        (``None``, ``""`` or ``[]``) is replaced by a default prompt first. If
        embeddings are supplied without a mask, an all-``False`` (no padding)
        mask is created.
        """
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        device = components._execution_device
        if block_state.prompt_embeds is None:
            block_state.prompt = block_state.prompt or "An explorable world"
            (
                block_state.prompt_embeds,
                block_state.prompt_pad_mask,
            ) = self.encode_prompt(components, block_state.prompt, device)
            block_state.prompt_embeds = block_state.prompt_embeds.contiguous()

        if block_state.prompt_pad_mask is None:
            # Caller supplied embeddings without a mask: treat every position as real.
            block_state.prompt_pad_mask = torch.zeros(
                block_state.prompt_embeds.shape[:2],
                dtype=torch.bool,
                device=device,
            )

        self.set_block_state(state, block_state)
        return components, state
|
|
|
|
|
|
|
|
class WorldEngineControllerEncoderStep(ModularPipelineBlocks):
    """Encodes controller inputs (mouse + buttons + scroll) for conditioning.

    Writes ``button_tensor``, ``mouse_tensor`` and ``scroll_tensor`` into the
    pipeline state. When a tensor is missing, it is allocated with shape
    ``(1, 1, n_buttons)``, ``(1, 1, 2)`` and ``(1, 1, 1)`` respectively. When a
    tensor is already present in the state, it is reused and its ``[0, 0]``
    entries are overwritten in place rather than reallocated.
    """

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Controller Encoder step that encodes mouse, button, and scroll inputs for conditioning"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        return [ConfigSpec("n_buttons", 256)]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "button",
                type_hint=Set[int],
                default=set(),
                description="Set of pressed button IDs",
            ),
            InputParam(
                "mouse",
                type_hint=Tuple[float, float],
                default=(0.0, 0.0),
                description="Mouse velocity (x, y)",
            ),
            InputParam(
                "scroll",
                type_hint=int,
                default=0,
                description="Scroll wheel direction (-1, 0, 1)",
            ),
            InputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="One-hot encoded button tensor",
            ),
            InputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Mouse velocity tensor",
            ),
            InputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Scroll wheel sign tensor",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="One-hot encoded button tensor",
            ),
            OutputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Mouse velocity tensor",
            ),
            OutputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Scroll wheel sign tensor",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> Tuple[ModularPipeline, PipelineState]:
        """Write the current controller inputs into the conditioning tensors."""
        block_state = self.get_block_state(state)
        device = components._execution_device
        # NOTE(review): `transformer` is not declared in `expected_components`;
        # this assumes the pipeline has it loaded elsewhere — confirm.
        dtype = components.transformer.dtype
        n_buttons = components.config.n_buttons

        # Buttons: allocate on first use, then reset and mark every pressed id.
        if block_state.button_tensor is None:
            block_state.button_tensor = torch.zeros(
                (1, 1, n_buttons), device=device, dtype=dtype
            )
        block_state.button_tensor.zero_()
        if block_state.button:
            for btn_id in block_state.button:
                # Out-of-range ids are ignored rather than raising.
                if 0 <= btn_id < n_buttons:
                    block_state.button_tensor[0, 0, btn_id] = 1.0

        # Mouse: (x, y) velocity; a None input is treated as no movement.
        if block_state.mouse_tensor is None:
            block_state.mouse_tensor = torch.zeros(
                (1, 1, 2), device=device, dtype=dtype
            )
        mouse = block_state.mouse if block_state.mouse is not None else (0.0, 0.0)
        block_state.mouse_tensor[0, 0, 0] = mouse[0]
        block_state.mouse_tensor[0, 0, 1] = mouse[1]

        # Scroll: only the sign is kept (-1.0, 0.0 or 1.0).
        if block_state.scroll_tensor is None:
            block_state.scroll_tensor = torch.zeros(
                (1, 1, 1), device=device, dtype=dtype
            )
        scroll = block_state.scroll if block_state.scroll is not None else 0
        block_state.scroll_tensor[0, 0, 0] = float(scroll > 0) - float(scroll < 0)

        self.set_block_state(state, block_state)
        return components, state
|
|
|