| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | """Text and controller encoder blocks for WorldEngine modular pipeline.""" |
| |
|
| | import html |
| | from typing import List, Set, Tuple, Union |
| |
|
| | import regex as re |
| | import torch |
| | from transformers import AutoTokenizer, UMT5EncoderModel |
| |
|
| | from diffusers.utils import is_ftfy_available, logging |
| | from diffusers.modular_pipelines import ( |
| | ModularPipelineBlocks, |
| | ModularPipeline, |
| | PipelineState, |
| | ) |
| | from diffusers.modular_pipelines.modular_pipeline_utils import ( |
| | ComponentSpec, |
| | ConfigSpec, |
| | InputParam, |
| | OutputParam, |
| | ) |
| |
|
| | if is_ftfy_available(): |
| | import ftfy |
| |
|
| |
|
# Module-level logger, following the diffusers convention of one logger per module.
logger = logging.get_logger(__name__)
| |
|
| |
|
def basic_clean(text):
    """Normalize raw prompt text before tokenization.

    Fixes mojibake via ftfy (when the optional dependency is installed) and
    unescapes HTML entities.

    Args:
        text: Raw prompt string.

    Returns:
        The cleaned, whitespace-stripped string.
    """
    # The `import ftfy` at the top of this file is guarded by is_ftfy_available(),
    # so an unconditional ftfy.fix_text() call would raise NameError when the
    # optional dependency is missing. Guard the call the same way.
    if is_ftfy_available():
        text = ftfy.fix_text(text)
    # Unescape twice to also handle doubly-encoded entities such as "&amp;amp;".
    text = html.unescape(html.unescape(text))
    return text.strip()
| |
|
| |
|
def whitespace_clean(text):
    """Collapse every run of whitespace in *text* to a single space and trim the ends."""
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()
| |
|
| |
|
def prompt_clean(text):
    """Fully normalize a prompt: unicode/HTML fixes, then whitespace collapsing."""
    return whitespace_clean(basic_clean(text))
| |
|
| |
|
class WorldEngineTextEncoderStep(ModularPipelineBlocks):
    """Encodes text prompts using UMT5-XL for conditioning."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return (
            "Text Encoder step that generates text embeddings to guide frame generation"
        )

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("text_encoder", UMT5EncoderModel),
            ComponentSpec("tokenizer", AutoTokenizer),
        ]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "prompt",
                description="The prompt or prompts to guide the frame generation",
            ),
            InputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                description="Pre-computed text embeddings",
            ),
            InputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                description="Padding mask for prompt embeddings",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt_embeds",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Text embeddings used to guide frame generation",
            ),
            OutputParam(
                "prompt_pad_mask",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Padding mask for prompt embeddings",
            ),
        ]

    @staticmethod
    def check_inputs(block_state):
        """Validate that `prompt`, when provided, is a `str` or a `list`.

        Raises:
            ValueError: If `prompt` is set but is neither a string nor a list.
        """
        # Idiomatic single isinstance call with a tuple of accepted types.
        if block_state.prompt is not None and not isinstance(
            block_state.prompt, (str, list)
        ):
            raise ValueError(
                f"`prompt` has to be of type `str` or `list` but is {type(block_state.prompt)}"
            )

    @staticmethod
    def encode_prompt(
        components,
        prompt: Union[str, List[str]],
        device: torch.device,
        max_sequence_length: int = 512,
    ):
        """Tokenize and encode `prompt` with the UMT5 text encoder.

        Args:
            components: Pipeline components exposing `tokenizer` and `text_encoder`.
            prompt: A single prompt or a list of prompts.
            device: Device on which to run tokenization outputs and the encoder.
            max_sequence_length: Token length prompts are padded/truncated to.

        Returns:
            Tuple `(prompt_embeds, prompt_pad_mask)`; the mask is `True` at
            padding positions.
        """
        dtype = components.text_encoder.dtype

        prompt = [prompt] if isinstance(prompt, str) else prompt
        prompt = [prompt_clean(p) for p in prompt]

        text_inputs = components.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

        text_input_ids = text_inputs.input_ids.to(device)
        attention_mask = text_inputs.attention_mask.to(device)

        prompt_embeds = components.text_encoder(
            text_input_ids, attention_mask
        ).last_hidden_state
        prompt_embeds = prompt_embeds.to(dtype=dtype)

        # Zero the embeddings at padding positions so downstream consumers do
        # not pick up values from pad tokens.
        prompt_embeds = prompt_embeds * attention_mask.unsqueeze(-1).type_as(
            prompt_embeds
        )

        # Pad-mask convention: True where attention_mask == 0 (i.e. padding).
        prompt_pad_mask = attention_mask.eq(0)

        return prompt_embeds, prompt_pad_mask

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Populate `prompt_embeds` / `prompt_pad_mask` on the pipeline state.

        Pre-computed embeddings on the state are respected as-is; otherwise the
        prompt (or a default placeholder prompt) is encoded.
        """
        block_state = self.get_block_state(state)
        self.check_inputs(block_state)

        device = components._execution_device
        if block_state.prompt_embeds is None:
            # Fall back to a generic prompt when the caller supplied none.
            block_state.prompt = block_state.prompt or "An explorable world"
            (
                block_state.prompt_embeds,
                block_state.prompt_pad_mask,
            ) = self.encode_prompt(components, block_state.prompt, device)
            block_state.prompt_embeds = block_state.prompt_embeds.contiguous()

        if block_state.prompt_pad_mask is None:
            # Pre-computed embeddings arrived without a mask: treat every
            # position as a real (non-padding) token.
            block_state.prompt_pad_mask = torch.zeros(
                block_state.prompt_embeds.shape[:2],
                dtype=torch.bool,
                device=device,
            )

        self.set_block_state(state, block_state)
        return components, state
| |
|
| |
|
class WorldEngineControllerEncoderStep(ModularPipelineBlocks):
    """Encodes controller inputs (mouse + buttons + scroll) for conditioning."""

    model_name = "world_engine"

    @property
    def description(self) -> str:
        return "Controller Encoder step that encodes mouse, button, and scroll inputs for conditioning"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return []

    @property
    def expected_configs(self) -> List[ConfigSpec]:
        # Fixed annotation: this returns ConfigSpec entries, not ComponentSpec.
        return [ConfigSpec("n_buttons", 256)]

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "button",
                type_hint=Set[int],
                default=set(),
                description="Set of pressed button IDs",
            ),
            InputParam(
                "mouse",
                type_hint=Tuple[float, float],
                default=(0.0, 0.0),
                description="Mouse velocity (x, y)",
            ),
            InputParam(
                "scroll",
                type_hint=int,
                default=0,
                description="Scroll wheel direction (-1, 0, 1)",
            ),
            InputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="One-hot encoded button tensor",
            ),
            InputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Mouse velocity tensor",
            ),
            InputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Scroll wheel sign tensor",
            ),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "button_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="One-hot encoded button tensor",
            ),
            OutputParam(
                "mouse_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Mouse velocity tensor",
            ),
            OutputParam(
                "scroll_tensor",
                type_hint=torch.Tensor,
                kwargs_type="denoiser_input_fields",
                description="Scroll wheel sign tensor",
            ),
        ]

    @torch.no_grad()
    def __call__(
        self, components: ModularPipeline, state: PipelineState
    ) -> PipelineState:
        """Encode the current controller inputs into conditioning tensors.

        Tensors already present on the state are reused (and updated in place);
        missing ones are allocated with batch/step shape (1, 1, ...).
        NOTE(review): tensors supplied by the caller are mutated in place here —
        confirm callers do not rely on their previous contents.
        """
        block_state = self.get_block_state(state)
        device = components._execution_device
        dtype = components.transformer.dtype

        n_buttons = components.config.n_buttons

        # One-hot button tensor of shape (1, 1, n_buttons).
        if block_state.button_tensor is None:
            block_state.button_tensor = torch.zeros(
                (1, 1, n_buttons), device=device, dtype=dtype
            )

        # Reset and set a 1.0 at each currently pressed, in-range button ID.
        block_state.button_tensor.zero_()
        if block_state.button:
            for btn_id in block_state.button:
                if 0 <= btn_id < n_buttons:
                    block_state.button_tensor[0, 0, btn_id] = 1.0

        # Mouse velocity tensor of shape (1, 1, 2): (x, y).
        if block_state.mouse_tensor is None:
            block_state.mouse_tensor = torch.zeros(
                (1, 1, 2), device=device, dtype=dtype
            )

        mouse = block_state.mouse if block_state.mouse is not None else (0.0, 0.0)
        block_state.mouse_tensor[0, 0, 0] = mouse[0]
        block_state.mouse_tensor[0, 0, 1] = mouse[1]

        # Scroll tensor of shape (1, 1, 1) holding only the sign of the scroll.
        if block_state.scroll_tensor is None:
            block_state.scroll_tensor = torch.zeros(
                (1, 1, 1), device=device, dtype=dtype
            )

        scroll = block_state.scroll if block_state.scroll is not None else 0
        # sign(scroll) without relying on scroll being a tensor: -1.0, 0.0, or 1.0.
        block_state.scroll_tensor[0, 0, 0] = float(scroll > 0) - float(scroll < 0)

        self.set_block_state(state, block_state)
        return components, state
| |
|